Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
- [ ] sequence
- [ ] shuffle
- [ ] slice
- [ ] sort_array
- [x] sort_array

### bitwise_funcs

Expand Down
37 changes: 17 additions & 20 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[ArrayMin] -> CometArrayMin,
classOf[ArrayRemove] -> CometArrayRemove,
classOf[ArrayRepeat] -> CometArrayRepeat,
classOf[SortArray] -> CometSortArray,
classOf[ArraysOverlap] -> CometArraysOverlap,
classOf[ArrayUnion] -> CometArrayUnion,
classOf[CreateArray] -> CometCreateArray,
Expand Down Expand Up @@ -778,30 +779,26 @@ object QueryPlanSerde extends Logging with CometExprShim {
* TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]]
*/
// scalastyle:on
def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
def canRank(dt: DataType): Boolean = {
dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: DecimalType =>
true
case _: DateType | _: TimestampType | _: TimestampNTZType =>
true
case _: BooleanType | _: BinaryType | _: StringType => true
case _ => false
}
/**
 * Returns true if `dt` is a scalar (non-nested) type on which Comet's native sort
 * is supported. Covers the fixed-width numerics, decimals, date/time types, and
 * boolean/binary/string. Nested types (arrays, maps, structs) are intentionally
 * excluded here; callers handle them separately.
 *
 * Fix: removed GitHub review-thread text that was scraped into the middle of the
 * match expression and broke the syntax.
 */
def supportedScalarSortElementType(dt: DataType): Boolean = {
  dt match {
    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
        _: DoubleType | _: DecimalType =>
      true
    case _: DateType | _: TimestampType | _: TimestampNTZType =>
      true
    case _: BooleanType | _: BinaryType | _: StringType =>
      true
    case _ =>
      false
  }
}

def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
if (sortOrder.length == 1) {
val canSort = sortOrder.head.dataType match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: DecimalType =>
true
case _: DateType | _: TimestampType | _: TimestampNTZType =>
true
case _: BooleanType | _: BinaryType | _: StringType => true
case ArrayType(elementType, _) => canRank(elementType)
case MapType(_, valueType, _) => canRank(valueType)
case _ => false
case ArrayType(elementType, _) => supportedScalarSortElementType(elementType)
case MapType(_, valueType, _) => supportedScalarSortElementType(valueType)
case _ => supportedScalarSortElementType(sortOrder.head.dataType)
}
if (!canSort) {
withInfo(op, s"Sort on single column of type ${sortOrder.head.dataType} is not supported")
Expand Down
77 changes: 76 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ package org.apache.comet.serde

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometExprShim
Expand Down Expand Up @@ -200,6 +201,80 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
}
}

/**
 * Serde support for Spark's `sort_array` expression, mapped to DataFusion's
 * `array_sort` scalar function.
 *
 * Fix: removed GitHub review-thread text that was scraped into the object body
 * (inside `containsFloatingPoint`, `supportedSortArrayElementType`, and `convert`)
 * and broke the syntax. Code logic is otherwise unchanged.
 */
object CometSortArray extends CometExpressionSerde[SortArray] {

  /**
   * Returns true if `dt` is, or transitively contains, a floating-point type.
   * Used to gate floating-point sorting behind the strict-floating-point config,
   * since float ordering is not 100% Spark-compatible.
   */
  private def containsFloatingPoint(dt: DataType): Boolean = {
    dt match {
      case FloatType | DoubleType => true
      case ArrayType(elementType, _) => containsFloatingPoint(elementType)
      case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
      case MapType(keyType, valueType, _) =>
        containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
      case _ => false
    }
  }

  /**
   * Returns true if an array with elements of type `dt` can be sorted natively.
   *
   * @param dt            the array element type to check
   * @param nestedInArray true when `dt` is itself nested inside another array
   *                      element (i.e. we are at least two array levels deep)
   */
  private def supportedSortArrayElementType(
      dt: DataType,
      nestedInArray: Boolean = false): Boolean = {
    dt match {
      // DataFusion's array_sort compares nested arrays through Arrow's rank kernel.
      // That kernel does not support Struct or Null child values,
      // so array<array<struct<...>>> and array<array<null>> would fail at runtime.
      case _: NullType if !nestedInArray =>
        true
      case ArrayType(elementType, _) =>
        supportedSortArrayElementType(elementType, nestedInArray = true)
      case StructType(fields) if !nestedInArray =>
        fields.forall(f => supportedSortArrayElementType(f.dataType))
      case _ =>
        supportedScalarSortElementType(dt)
    }
  }

  /**
   * Reports whether this `sort_array` call is supported: unsupported element types
   * are rejected outright, and floating-point elements are flagged Incompatible
   * when strict floating-point mode is enabled.
   */
  override def getSupportLevel(expr: SortArray): SupportLevel = {
    // sort_array's input is always an array, so this cast is safe by construction.
    val elementType = expr.base.dataType.asInstanceOf[ArrayType].elementType

    if (!supportedSortArrayElementType(elementType)) {
      Unsupported(Some(s"Sort on array element type $elementType is not supported"))
    } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
      containsFloatingPoint(elementType)) {
      Incompatible(
        Some(
          "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
            s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
            s"${CometConf.COMPAT_GUIDE}"))
    } else {
      Compatible()
    }
  }

  /**
   * Converts `sort_array` to a DataFusion `array_sort` scalar function call.
   * The ascending-order argument must be a boolean literal; Spark's ascending sort
   * places nulls first and descending places nulls last, which is encoded here via
   * the direction / null-ordering string literals.
   */
  override def convert(
      expr: SortArray,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val arrayExprProto = exprToProtoInternal(expr.base, inputs, binding)
    val (sortDirectionExprProto, nullOrderingExprProto) = expr.ascendingOrder match {
      case Literal(value: Boolean, BooleanType) =>
        val direction = if (value) "ASC" else "DESC"
        val nullOrdering = if (value) "NULLS FIRST" else "NULLS LAST"
        (
          exprToProtoInternal(Literal(direction), inputs, binding),
          exprToProtoInternal(Literal(nullOrdering), inputs, binding))
      case other =>
        withInfo(expr, s"ascendingOrder must be a boolean literal: $other")
        (None, None)
    }

    val sortArrayScalarExpr =
      scalarFunctionExprToProto(
        "array_sort",
        arrayExprProto,
        sortDirectionExprProto,
        nullOrderingExprProto)
    optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _*)
  }
}

object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {

// Declared Incompatible with no reason string: array_intersect is available natively
// but its results are presumably not guaranteed to match Spark exactly — TODO confirm
// against the compat guide before relying on it without explicit opt-in.
override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)
Expand Down
Loading