Skip to content

Commit a23cd04

Browse files
authored
Merge branch 'main' into df54
2 parents 77ac0db + 5efd972 commit a23cd04

19 files changed

Lines changed: 636 additions & 323 deletions

File tree

docs/source/contributor-guide/adding_a_new_operator.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,14 @@ For operators that run in the JVM:
553553
Example pattern from `CometExecRule.scala`:
554554

555555
```scala
556-
case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
557-
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
556+
case s: ShuffleExchangeExec =>
557+
CometShuffleExchangeExec.shuffleSupported(s) match {
558+
case Some(CometNativeShuffle) =>
559+
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
560+
case Some(CometColumnarShuffle) =>
561+
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
562+
case None => s
563+
}
558564
```
559565

560566
## Common Patterns and Helpers

docs/source/user-guide/latest/compatibility.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ Expressions that are not 100% Spark-compatible will fall back to Spark by defaul
5858
`spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
5959
the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
6060

61+
### Aggregate Expressions
62+
63+
- **CollectSet**: Comet deduplicates NaN values (treats `NaN == NaN`) while Spark treats each NaN as a distinct value.
64+
When `spark.comet.exec.strictFloatingPoint=true`, `collect_set` on floating-point types falls back to Spark unless
65+
`spark.comet.expression.CollectSet.allowIncompatible=true` is set.
66+
6167
### Array Expressions
6268

6369
- **ArraysOverlap**: Inconsistent behavior when arrays contain NULL values.

docs/source/user-guide/latest/expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
203203
| BitXorAgg | | Yes | |
204204
| BoolAnd | `bool_and` | Yes | |
205205
| BoolOr | `bool_or` | Yes | |
206+
| CollectSet | | No | NaN dedup differs from Spark. See compatibility guide. |
206207
| Corr | | Yes | |
207208
| Count | | Yes | |
208209
| CovPopulation | | Yes | |

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
- [x] bool_and
3434
- [x] bool_or
3535
- [ ] collect_list
36-
- [ ] collect_set
36+
- [x] collect_set
3737
- [ ] corr
3838
- [x] count
3939
- [x] count_if

native/core/src/execution/planner.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ use datafusion_comet_spark_expr::{
7070
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
7171
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SumInteger, ToCsv,
7272
};
73+
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
7374
use iceberg::expr::Bind;
7475

7576
use crate::execution::operators::ExecutionError::GeneralError;
@@ -2259,6 +2260,11 @@ impl PhysicalPlanner {
22592260
));
22602261
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
22612262
}
2263+
AggExprStruct::CollectSet(expr) => {
2264+
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
2265+
let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
2266+
Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
2267+
}
22622268
}
22632269
}
22642270

native/proto/src/proto/expr.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ message AggExpr {
140140
Stddev stddev = 14;
141141
Correlation correlation = 15;
142142
BloomFilterAgg bloomFilterAgg = 16;
143+
CollectSet collectSet = 17;
143144
}
144145

145146
// Optional filter expression for SQL FILTER (WHERE ...) clause.
@@ -248,6 +249,11 @@ message BloomFilterAgg {
248249
DataType datatype = 4;
249250
}
250251

252+
message CollectSet {
253+
Expr child = 1;
254+
DataType datatype = 2;
255+
}
256+
251257
enum EvalMode {
252258
LEGACY = 0;
253259
TRY = 1;

spark/src/main/scala/org/apache/comet/CometFallback.scala

Lines changed: 0 additions & 67 deletions
This file was deleted.

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -200,22 +200,29 @@ object CometSparkSessionExtensions extends Logging {
200200
}
201201

202202
/**
203-
* Attaches explain information to a TreeNode, rolling up the corresponding information tags
204-
* from any child nodes. For now, we are using this to attach the reasons why certain Spark
205-
* operators or expressions are disabled.
203+
* Record a fallback reason on a `TreeNode` (a Spark operator or expression) explaining why
204+
* Comet cannot accelerate it. Reasons recorded here are surfaced in extended explain output
205+
* (see `ExtendedExplainInfo`) and, when `COMET_LOG_FALLBACK_REASONS` is enabled, logged as
206+
* warnings. The reasons are also rolled up from child nodes so that the operator that remains
207+
* in the Spark plan carries the reasons from its converted-away subtree.
208+
*
209+
* Call this in any code path where Comet decides not to convert a given node - serde `convert`
210+
* methods returning `None`, unsupported data types, disabled configs, etc. Do not use this for
211+
* informational messages that are not fallback reasons: anything tagged here is treated by the
212+
* rules as a signal that the node falls back to Spark.
206213
*
207214
* @param node
208-
* The node to attach the explain information to. Typically a SparkPlan
215+
* The Spark operator or expression that is falling back to Spark.
209216
* @param info
210-
* Information text. Optional, may be null or empty. If not provided, then only information
211-
* from child nodes will be included.
217+
* The fallback reason. Optional, may be null or empty - pass empty only when the call is used
218+
* purely to roll up reasons from `exprs`.
212219
* @param exprs
213-
* Child nodes. Information attached in these nodes will be be included in the information
214-
* attached to @node
220+
* Child nodes whose own fallback reasons should be rolled up into `node`. Pass the
221+
* sub-expressions or child operators whose failure caused `node` to fall back.
215222
* @tparam T
216-
* The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
223+
* The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`.
217224
* @return
218-
* The node with information (if any) attached
225+
* `node` with fallback reasons attached (as a side effect on its tag map).
219226
*/
220227
def withInfo[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = {
221228
// support existing approach of passing in multiple infos in a newline-delimited string
@@ -228,22 +235,24 @@ object CometSparkSessionExtensions extends Logging {
228235
}
229236

230237
/**
231-
* Attaches explain information to a TreeNode, rolling up the corresponding information tags
232-
* from any child nodes. For now, we are using this to attach the reasons why certain Spark
233-
* operators or expressions are disabled.
238+
* Record one or more fallback reasons on a `TreeNode` and roll up reasons from any child nodes.
239+
* This is the set-valued form of [[withInfo]]; see that overload for the full contract.
240+
*
241+
* Reasons are accumulated (never overwritten) on the node's `EXTENSION_INFO` tag and are
242+
* surfaced in extended explain output. When `COMET_LOG_FALLBACK_REASONS` is enabled, each new
243+
* reason is also emitted as a warning.
234244
*
235245
* @param node
236-
* The node to attach the explain information to. Typically a SparkPlan
246+
* The Spark operator or expression that is falling back to Spark.
237247
* @param info
238-
* Information text. May contain zero or more strings. If not provided, then only information
239-
* from child nodes will be included.
248+
* The fallback reasons for this node. May be empty when the call is used purely to roll up
249+
* child reasons.
240250
* @param exprs
241-
* Child nodes. Information attached in these nodes will be be included in the information
242-
* attached to @node
251+
* Child nodes whose own fallback reasons should be rolled up into `node`.
243252
* @tparam T
244-
* The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
253+
* The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`.
245254
* @return
246-
* The node with information (if any) attached
255+
* `node` with fallback reasons attached (as a side effect on its tag map).
247256
*/
248257
def withInfos[T <: TreeNode[_]](node: T, info: Set[String], exprs: T*): T = {
249258
if (CometConf.COMET_LOG_FALLBACK_REASONS.get()) {
@@ -259,25 +268,27 @@ object CometSparkSessionExtensions extends Logging {
259268
}
260269

261270
/**
262-
* Attaches explain information to a TreeNode, rolling up the corresponding information tags
263-
* from any child nodes
271+
* Roll up fallback reasons from `exprs` onto `node` without adding a new reason of its own. Use
272+
* this when a parent operator is itself falling back and wants to preserve the reasons recorded
273+
* on its child expressions/operators so they appear together in explain output.
264274
*
265275
* @param node
266-
* The node to attach the explain information to. Typically a SparkPlan
276+
* The parent operator or expression falling back to Spark.
267277
* @param exprs
268-
* Child nodes. Information attached in these nodes will be be included in the information
269-
* attached to @node
278+
* Child nodes whose fallback reasons should be aggregated onto `node`.
270279
* @tparam T
271-
* The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
280+
* The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`.
272281
* @return
273-
* The node with information (if any) attached
282+
* `node` with the rolled-up reasons attached (as a side effect on its tag map).
274283
*/
275284
def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = {
276285
withInfos(node, Set.empty, exprs: _*)
277286
}
278287

279288
/**
280-
* Checks whether a TreeNode has any explain information attached
289+
* True if any fallback reason has been recorded on `node` (via [[withInfo]] / [[withInfos]]).
290+
* Callers that need to short-circuit when a prior rule pass has already decided a node falls
291+
* back can use this as the sticky signal.
281292
*/
282293
def hasExplainInfo(node: TreeNode[_]): Boolean = {
283294
node.getTagValue(CometExplainInfo.EXTENSION_INFO).exists(_.nonEmpty)

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,18 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
9898
private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
9999

100100
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
101-
plan.transformUp {
102-
case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
103-
// Switch to use Decimal128 regardless of precision, since Arrow native execution
104-
// doesn't support Decimal32 and Decimal64 yet.
105-
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
106-
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
107-
108-
case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
109-
// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
110-
// (if configured)
111-
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
101+
plan.transformUp { case s: ShuffleExchangeExec =>
102+
CometShuffleExchangeExec.shuffleSupported(s) match {
103+
case Some(CometNativeShuffle) =>
104+
// Switch to use Decimal128 regardless of precision, since Arrow native execution
105+
// doesn't support Decimal32 and Decimal64 yet.
106+
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
107+
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
108+
case Some(CometColumnarShuffle) =>
109+
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
110+
case None =>
111+
s
112+
}
112113
}
113114
}
114115

spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.apache.comet.serde
2121

2222
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, NullsFirst, NullsLast, SortOrder}
23-
import org.apache.spark.sql.types._
2423

2524
import org.apache.comet.CometConf
2625
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -30,19 +29,8 @@ object CometSortOrder extends CometExpressionSerde[SortOrder] {
3029

3130
override def getSupportLevel(expr: SortOrder): SupportLevel = {
3231

33-
def containsFloatingPoint(dt: DataType): Boolean = {
34-
dt match {
35-
case DataTypes.FloatType | DataTypes.DoubleType => true
36-
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
37-
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
38-
case MapType(keyType, valueType, _) =>
39-
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
40-
case _ => false
41-
}
42-
}
43-
4432
if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
45-
containsFloatingPoint(expr.child.dataType)) {
33+
SupportLevel.containsFloatingPoint(expr.child.dataType)) {
4634
// https://github.com/apache/datafusion-comet/issues/2626
4735
Incompatible(
4836
Some(

0 commit comments

Comments
 (0)