Skip to content

Commit f7495da

Browse files
stefankandic authored and cloud-fan committed
[SPARK-51501][SQL] Disable ObjectHashAggregate for group by on collated columns
### What changes were proposed in this pull request? Disabling `ObjectHashAggregate` when grouping on columns with collations. ### Why are the changes needed? #45290 added support for sort based aggregation on collated columns and explicitly forbade the use of hash aggregate for collated columns. However, it did not consider the third type of aggregate, the object hash aggregate, which is only used when there are also TypedImperativeAggregate expressions present ([source](https://github.com/apache/spark/blob/f3b081066393e1568c364b6d3bc0bceabd1e7e9f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala#L1204)). That means that if we group by a collated column and also have a TypedImperativeAggregate we will end up using the object hash aggregate which can lead to incorrect results like in the example below: ```code CREATE TABLE tbl(c1 STRING COLLATE UTF8_LCASE, c2 INT) USING PARQUET; INSERT INTO tbl VALUES ('HELLO', 1), ('hello', 2), ('HeLlO', 3); SELECT COLLECT_LIST(c2) as list FROM tbl GROUP BY c1; ``` where the result would have three rows with values [1], [2] and [3] instead of one row with value [1, 2, 3]. For this reason we should do the same thing as we did for the regular hash aggregate, make it so that it doesn't support grouping expressions on collated columns. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50267 from stefankandic/fixObjectHashAgg. Authored-by: Stefan Kandic <stefan.kandic@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 84b9848 commit f7495da

File tree

4 files changed

+104
-4
lines changed

4 files changed

+104
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,10 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
366366
newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate ||
367367
newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && {
368368
val Seq(newPlanSupportsObjectHashAggregate, cachedPlanSupportsObjectHashAggregate) =
369-
aggregateExpressionsSeq.map(aggregateExpressions =>
370-
Aggregate.supportsObjectHashAggregate(aggregateExpressions))
369+
aggregateExpressionsSeq.zip(groupByExpressionSeq).map {
370+
case (aggregateExpressions, groupByExpressions) =>
371+
Aggregate.supportsObjectHashAggregate(aggregateExpressions, groupByExpressions)
372+
}
371373
newPlanSupportsObjectHashAggregate && cachedPlanSupportsObjectHashAggregate ||
372374
newPlanSupportsObjectHashAggregate == cachedPlanSupportsObjectHashAggregate
373375
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,14 @@ object Aggregate {
11991199
groupingExpression.forall(e => UnsafeRowUtils.isBinaryStable(e.dataType))
12001200
}
12011201

1202-
def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
1202+
def supportsObjectHashAggregate(
1203+
aggregateExpressions: Seq[AggregateExpression],
1204+
groupingExpressions: Seq[Expression]): Boolean = {
1205+
// We should not use hash aggregation on binary unstable types.
1206+
if (groupingExpressions.exists(e => !UnsafeRowUtils.isBinaryStable(e.dataType))) {
1207+
return false
1208+
}
1209+
12031210
aggregateExpressions.map(_.aggregateFunction).exists {
12041211
case _: TypedImperativeAggregate[_] => true
12051212
case _ => false

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ object AggUtils {
9494
child = child)
9595
} else {
9696
val objectHashEnabled = child.conf.useObjectHashAggregation
97-
val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions)
97+
val useObjectHash = Aggregate.supportsObjectHashAggregate(
98+
aggregateExpressions, groupingExpressions)
9899

99100
if (forceObjHashAggregate || (objectHashEnabled && useObjectHash && !forceSortAggregate)) {
100101
ObjectHashAggregateExec(
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.collation
19+
20+
import org.apache.spark.sql.{QueryTest, Row}
21+
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
22+
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
23+
import org.apache.spark.sql.test.SharedSparkSession
24+
25+
/**
 * Verifies aggregation behavior when grouping on collated string columns.
 *
 * Grouping expressions whose types are not binary-stable (e.g. STRING COLLATE
 * UTF8_LCASE, where distinct binary values compare as equal) must not be
 * evaluated with hash-based aggregates; these tests pin that the planner
 * falls back to sort-based aggregation and that forcing the object hash
 * aggregate reproduces the incorrect grouping.
 */
class CollationAggregationSuite
  extends QueryTest
  with SharedSparkSession
  with AdaptiveSparkPlanHelper {

  test("group by collated column doesn't work with obj hash aggregate") {
    val table = "grp_by_tbl"
    withTable(table) {
      sql(s"CREATE TABLE $table (c1 STRING COLLATE UTF8_LCASE, c2 INT) USING PARQUET")
      sql(s"INSERT INTO $table VALUES ('hello', 1), ('HELLO', 2), ('HeLlO', 3)")

      // Without forcing the object hash aggregate, all three case-variants of
      // 'hello' collapse into a single group under UTF8_LCASE.
      checkAnswer(
        sql(s"SELECT COUNT(*) FROM $table GROUP BY c1"),
        Seq(Row(3)))

      // Forcing the object hash aggregate groups on raw binary values, so the
      // three variants incorrectly land in separate groups.
      withSQLConf("spark.sql.test.forceApplyObjectHashAggregate" -> "true") {
        checkAnswer(
          sql(s"SELECT COUNT(*) FROM $table GROUP BY c1"),
          Seq(Row(1), Row(1), Row(1)))

        checkAnswer(
          sql(s"SELECT COLLECT_LIST(c2) AS c3 FROM $table GROUP BY c1 ORDER BY c3"),
          Seq(Row(Seq(1)), Row(Seq(2)), Row(Seq(3))))
      }
    }
  }

  test("imperative aggregate fn does not use objectHashAggregate when group by collated column") {
    val table = "imp_agg"
    // Exercise both settings of the object-hash-aggregate feature flag; the
    // planner must pick sort aggregation in either case.
    for (objHashEnabled <- Seq(true, false)) {
      withTable(table) {
        withSQLConf("spark.sql.execution.useObjectHashAggregateExec" -> objHashEnabled.toString) {
          sql(
            s"""
               |CREATE TABLE $table (
               |  c1 STRING COLLATE UTF8_LCASE,
               |  c2 INT
               |) USING PARQUET
               |""".stripMargin)
          sql(s"INSERT INTO $table VALUES ('HELLO', 1), ('hello', 2), ('HeLlO', 3)")

          val aggregated = sql(s"SELECT COLLECT_LIST(c2) as list FROM $table GROUP BY c1")
          val plan = aggregated.queryExecution.executedPlan

          // Neither hash-based aggregate node may appear in the plan.
          collectFirst(plan) {
            case _: ObjectHashAggregateExec => fail("ObjectHashAggregateExec should not be used.")
            case _: HashAggregateExec => fail("HashAggregateExec should not be used.")
          }

          // The planner must have fallen back to sort-based aggregation.
          assert(collectFirst(plan) {
            case _: SortAggregateExec => true
          }.nonEmpty)

          // All three rows belong to one group; sort the list for a
          // deterministic comparison.
          checkAnswer(
            aggregated.selectExpr("array_sort(list)"),
            Seq(Row(Seq(1, 2, 3)))
          )
        }
      }
    }
  }
}

0 commit comments

Comments
 (0)