Skip to content

Commit 1010dab

Browse files
mgaido91HyukjinKwon
authored andcommitted
[SPARK-55969][SQL] regr_r2 should treat first param as dependent variable
### What changes were proposed in this pull request? The `regr_r2` function currently is swapping the two parameters when processing the results. This causes issues only in few special cases, as otherwise the result does not depend from the order of its parameters. ### Why are the changes needed? With special cases, the result is wrong. ### Does this PR introduce _any_ user-facing change? Yes, fixes the result of regr_r2 in the special cases in which either the dependent or independent variable has always the same value. ### How was this patch tested? Added UTs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #54901 from mgaido91/SPARK-55969. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 81e8678 commit 1010dab

File tree

4 files changed

+50
-2
lines changed

4 files changed

+50
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ case class RegrR2(y: Expression, x: Expression) extends PearsonCorrelation(y, x,
147147
override def prettyName: String = "regr_r2"
148148
override val evaluateExpression: Expression = {
149149
val corr = ck / sqrt(xMk * yMk)
150-
If(xMk === 0.0, Literal.create(null, DoubleType),
151-
If(yMk === 0.0, Literal.create(1.0, DoubleType), corr * corr))
150+
// In PearsonCorrelation, x and y are swapped, so here xMk refers to the dependent variable
151+
// and yMk to the independent variable
152+
If(yMk === 0.0, Literal.create(null, DoubleType),
153+
If(xMk === 0.0, Literal.create(1.0, DoubleType), corr * corr))
152154
}
153155
override protected def withNewChildrenInternal(
154156
newLeft: Expression, newRight: Expression): RegrR2 =

sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,29 @@ Aggregate [k#x], [k#x, regr_intercept(cast(y#x as double), cast(x#x as double))
407407
+- Project [k#x, y#x, x#x]
408408
+- SubqueryAlias testRegression
409409
+- LocalRelation [k#x, y#x, x#x]
410+
411+
412+
-- !query
413+
SELECT regr_r2(k, x) FROM testRegression where k=2
414+
-- !query analysis
415+
Aggregate [regr_r2(cast(k#x as double), cast(x#x as double)) AS regr_r2(k, x)#x]
416+
+- Filter (k#x = 2)
417+
+- SubqueryAlias testregression
418+
+- View (`testRegression`, [k#x, y#x, x#x])
419+
+- Project [cast(k#x as int) AS k#x, cast(y#x as int) AS y#x, cast(x#x as int) AS x#x]
420+
+- Project [k#x, y#x, x#x]
421+
+- SubqueryAlias testRegression
422+
+- LocalRelation [k#x, y#x, x#x]
423+
424+
425+
-- !query
426+
SELECT regr_r2(y, k) FROM testRegression where k=2
427+
-- !query analysis
428+
Aggregate [regr_r2(cast(y#x as double), cast(k#x as double)) AS regr_r2(y, k)#x]
429+
+- Filter (k#x = 2)
430+
+- SubqueryAlias testregression
431+
+- View (`testRegression`, [k#x, y#x, x#x])
432+
+- Project [cast(k#x as int) AS k#x, cast(y#x as int) AS y#x, cast(x#x as int) AS x#x]
433+
+- Project [k#x, y#x, x#x]
434+
+- SubqueryAlias testRegression
435+
+- LocalRelation [k#x, y#x, x#x]

sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,7 @@ SELECT regr_intercept(y, x) FROM testRegression;
5050
SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL;
5151
SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k;
5252
SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k;
53+
54+
-- SPARK-55969: regr_r2 should treat first param as dependent variable
55+
SELECT regr_r2(k, x) FROM testRegression where k=2;
56+
SELECT regr_r2(y, k) FROM testRegression where k=2;

sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,19 @@ SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS
274274
struct<k:int,regr_intercept(y, x):double>
275275
-- !query output
276276
2 1.1547344110854496
277+
278+
279+
-- !query
280+
SELECT regr_r2(k, x) FROM testRegression where k=2
281+
-- !query schema
282+
struct<regr_r2(k, x):double>
283+
-- !query output
284+
1.0
285+
286+
287+
-- !query
288+
SELECT regr_r2(y, k) FROM testRegression where k=2
289+
-- !query schema
290+
struct<regr_r2(y, k):double>
291+
-- !query output
292+
NULL

0 commit comments

Comments
 (0)