Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
- [ ] sequence
- [ ] shuffle
- [ ] slice
- [ ] sort_array
- [x] sort_array

### bitwise_funcs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[ArrayMin] -> CometArrayMin,
classOf[ArrayRemove] -> CometArrayRemove,
classOf[ArrayRepeat] -> CometArrayRepeat,
classOf[SortArray] -> CometSortArray,
classOf[ArraysOverlap] -> CometArraysOverlap,
classOf[ArrayUnion] -> CometArrayUnion,
classOf[CreateArray] -> CometCreateArray,
Expand Down
82 changes: 81 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ package org.apache.comet.serde

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometExprShim
Expand Down Expand Up @@ -200,6 +201,85 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
}
}

object CometSortArray extends CometExpressionSerde[SortArray] {
/**
 * Reports whether `dt` is, or transitively contains, a `FloatType` or `DoubleType`.
 *
 * Arrays are inspected through their element type, maps through both key and value types,
 * and structs through every field. Any other type yields `false`.
 *
 * @param dt
 *   the data type to inspect
 * @return
 *   `true` if a floating-point type occurs anywhere inside `dt`
 */
private def containsFloatingPoint(dt: DataType): Boolean = dt match {
  case FloatType | DoubleType => true
  case ArrayType(inner, _) => containsFloatingPoint(inner)
  case MapType(k, v, _) => containsFloatingPoint(k) || containsFloatingPoint(v)
  case StructType(members) => members.exists(member => containsFloatingPoint(member.dataType))
  case _ => false
}

/**
 * Decides whether an array whose elements have type `dt` is accepted for `sort_array`.
 *
 * Primitive numeric, decimal, date/timestamp, boolean, binary and string types are always
 * accepted. Arrays are accepted when their element type is accepted. Structs and `NullType`
 * are accepted only while the walk has not yet descended into an inner array.
 *
 * @param dt
 *   the element type (or a type nested inside it) to check
 * @param nestedInArray
 *   `true` once the check has descended into an array nested inside the sorted array
 * @return
 *   `true` if the type is accepted, `false` if the expression must be reported unsupported
 */
private def canRank(dt: DataType, nestedInArray: Boolean = false): Boolean = {
  dt match {
    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
        _: DoubleType | _: DecimalType =>
      true
    case _: DateType | _: TimestampType | _: TimestampNTZType =>
      true
    case _: BooleanType | _: BinaryType | _: StringType =>
      true
    // NullType is subject to the same nested-in-array restriction as structs below: a
    // top-level array<null> is accepted, a NullType reached through an inner array is not.
    case _: NullType =>
      !nestedInArray
    case ArrayType(elementType, _) =>
      canRank(elementType, nestedInArray = true)
    // Structs are rejected once we are inside an inner array (e.g. array<array<struct<...>>>),
    // so such queries fall back to Spark.
    // NOTE(review): the underlying reason is not visible here -- presumably the native
    // array_sort cannot order these nested values the way Spark does; confirm and document.
    case StructType(fields) if !nestedInArray =>
      // Field types restart with nestedInArray = false (the parameter default).
      fields.forall(f => canRank(f.dataType))
    case _ =>
      false
  }
}

/**
 * Classifies how well `sort_array` on the given expression is supported.
 *
 *   - `Unsupported` when the input is not an array or its element type fails `canRank`.
 *   - `Incompatible` when strict floating-point mode is enabled and the element type
 *     contains a float or double.
 *   - `Compatible` otherwise.
 *
 * @param expr
 *   the `SortArray` expression being planned
 * @return
 *   the support level, with an explanatory message for the non-compatible cases
 */
override def getSupportLevel(expr: SortArray): SupportLevel = {
  // Pattern match instead of an unchecked downcast so an unexpected input type is reported
  // as unsupported rather than surfacing as a ClassCastException.
  expr.base.dataType match {
    case ArrayType(elementType, _) if !canRank(elementType) =>
      Unsupported(Some(s"Sort on array element type $elementType is not supported"))
    case ArrayType(elementType, _)
        if CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
          containsFloatingPoint(elementType) =>
      Incompatible(
        Some(
          "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
            s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
            s"${CometConf.COMPAT_GUIDE}"))
    case _: ArrayType =>
      Compatible()
    case other =>
      Unsupported(Some(s"Sort requires an array input but found $other"))
  }
}

/**
 * Serializes a `SortArray` expression as a call to the `array_sort` scalar function.
 *
 * The call takes three arguments: the array, a direction literal ("ASC" / "DESC") and a
 * null-ordering literal ("NULLS FIRST" when ascending, "NULLS LAST" when descending). Both
 * literals are derived from `ascendingOrder`, which must be a boolean literal; otherwise a
 * fallback reason is recorded on `expr` and neither literal is produced.
 *
 * @param expr
 *   the `SortArray` expression to convert
 * @param inputs
 *   attributes passed through to `exprToProtoInternal`
 * @param binding
 *   binding flag passed through to `exprToProtoInternal`
 * @return
 *   the result of `optExprWithInfo` applied to the built `array_sort` call
 */
override def convert(
    expr: SortArray,
    inputs: Seq[Attribute],
    binding: Boolean): Option[ExprOuterClass.Expr] = {
  val baseProto = exprToProtoInternal(expr.base, inputs, binding)

  // A single match yields both literal arguments, direction first and null ordering second.
  val (directionProto, nullOrderingProto) = expr.ascendingOrder match {
    case Literal(ascending: Boolean, BooleanType) =>
      val directionName = if (ascending) "ASC" else "DESC"
      val nullsName = if (ascending) "NULLS FIRST" else "NULLS LAST"
      val direction = exprToProtoInternal(Literal(directionName), inputs, binding)
      val nulls = exprToProtoInternal(Literal(nullsName), inputs, binding)
      (direction, nulls)
    case other =>
      withInfo(expr, s"ascendingOrder must be a boolean literal: $other")
      (None, None)
  }

  val sortCall =
    scalarFunctionExprToProto("array_sort", baseProto, directionProto, nullOrderingProto)
  optExprWithInfo(sortCall, expr, expr.children: _*)
}
}

object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {

override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)
Expand Down
199 changes: 199 additions & 0 deletions spark/src/test/resources/sql-tests/expressions/array/sort_array.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- ConfigMatrix: parquet.enable.dictionary=false,true

statement
Comment thread
grorge123 marked this conversation as resolved.
CREATE TABLE test_sort_array_int(arr array<int>) USING parquet

statement
INSERT INTO test_sort_array_int VALUES
(array(3, 1, 4, 1, 5)),
(array(3, NULL, 1, NULL, 2)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_int

query
SELECT sort_array(arr, true) FROM test_sort_array_int

query
SELECT sort_array(arr, false) FROM test_sort_array_int
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 This covers both cases mentioned in Spark's comment:

Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order.


statement
CREATE TABLE test_sort_array_string(arr array<string>) USING parquet

statement
INSERT INTO test_sort_array_string VALUES
(array('d', 'c', 'b', 'a')),
(array('b', NULL, 'a')),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_string

query
SELECT sort_array(arr, true) FROM test_sort_array_string

query
SELECT sort_array(arr, false) FROM test_sort_array_string

statement
CREATE TABLE test_sort_array_float(arr array<double>) USING parquet
Comment thread
grorge123 marked this conversation as resolved.
Outdated

statement
INSERT INTO test_sort_array_float VALUES
(array(CAST('NaN' AS DOUBLE), 3.0, 1.0, NULL, -0.0, 0.0)),
(array(CAST('NaN' AS DOUBLE), CAST('NaN' AS DOUBLE), -5.0, 2.0)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_float

query
SELECT sort_array(arr, true) FROM test_sort_array_float

query
SELECT sort_array(arr, false) FROM test_sort_array_float

statement
CREATE TABLE test_sort_array_decimal(arr array<decimal(10, 0)>) USING parquet

statement
INSERT INTO test_sort_array_decimal VALUES
(array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0)))),
Comment thread
grorge123 marked this conversation as resolved.
Outdated
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_decimal

query
SELECT sort_array(arr, true) FROM test_sort_array_decimal

query
SELECT sort_array(arr, false) FROM test_sort_array_decimal

statement
CREATE TABLE test_sort_array_boolean(arr array<boolean>) USING parquet

statement
INSERT INTO test_sort_array_boolean VALUES
(array(true, false, true, false)),
(array(true, false, true, NULL, false)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_boolean

query
SELECT sort_array(arr, true) FROM test_sort_array_boolean

query
SELECT sort_array(arr, false) FROM test_sort_array_boolean

statement
CREATE TABLE test_sort_array_struct(arr array<struct<a:int,b:string>>) USING parquet

statement
INSERT INTO test_sort_array_struct VALUES
(array(
named_struct('a', 2, 'b', 'b'),
named_struct('a', 1, 'b', 'c'),
named_struct('a', 1, 'b', 'a'))),
(array(
named_struct('a', 2, 'b', NULL),
named_struct('a', 1, 'b', 'z'),
named_struct('a', 1, 'b', NULL))),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_struct

query
SELECT sort_array(arr, false) FROM test_sort_array_struct

statement
CREATE TABLE test_sort_array_nested(arr array<array<int>>) USING parquet

statement
INSERT INTO test_sort_array_nested VALUES
(array(array(2, 3), array(1), array(2, 1))),
(array(array(1, NULL), array(1), NULL)),
(array()),
(NULL)

query
SELECT sort_array(arr) FROM test_sort_array_nested

query
SELECT sort_array(arr, false) FROM test_sort_array_nested

statement
CREATE TABLE test_sort_array_nested_struct(arr array<array<struct<a:int>>>) USING parquet

statement
INSERT INTO test_sort_array_nested_struct VALUES
(array(
array(named_struct('a', 2)),
array(named_struct('a', 1)))),
(array()),
(NULL)

query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType)
SELECT sort_array(arr) FROM test_sort_array_nested_struct

query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType)
SELECT sort_array(arr, false) FROM test_sort_array_nested_struct

-- literal arguments
query
SELECT
sort_array(array(3, 1, 4, 1, 5)),
sort_array(array(3, 1, 4, 1, 5), true),
sort_array(array(3, NULL, 1, NULL, 2)),
sort_array(array(3, NULL, 1, NULL, 2), false),
sort_array(array(CAST('NaN' AS DOUBLE), 1.0, NULL, -0.0, 0.0)),
sort_array(array(CAST('NaN' AS DOUBLE), 1.0, NULL, -0.0, 0.0), false),
sort_array(array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0)))),
sort_array(
array(CAST(100 AS DECIMAL(10, 0)), CAST(10 AS DECIMAL(10, 0))),
false),
sort_array(array(true, false, true, false)),
sort_array(array(true, false, true, NULL, false)),
sort_array(array(true, false, true, NULL, false), false),
sort_array(
array(
named_struct('a', 2, 'b', 'b'),
named_struct('a', 1, 'b', 'c'),
named_struct('a', 1, 'b', 'a'))),
sort_array(array(array(2, 3), array(1), array(2, 1))),
sort_array(array(array(1, NULL), array(1), NULL)),
sort_array(array(NULL, NULL)),
sort_array(cast(NULL as array<int>))

query expect_fallback(Sort on array element type ArrayType(StructType(StructField(a,IntegerType)
SELECT sort_array(
array(
array(named_struct('a', 2)),
array(named_struct('a', 1))))
Loading