Skip to content

Commit 796c7d1

Browse files
Jefffreyalamb
andauthored
feat: support f16 in coercion logic (#18944)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #18943 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Originally: > As pointed out by @martin-g, even though we plan to remove `NUMERICS` (see #18092) we should probably add f16 first so we don't conflate adding new functionality with refactoring changes. Updated: > #19727 removes `NUMERICS` for us, which surfaced a bug where f16 wasn't being coerced to f64. Turns out we didn't have f16 support in the logic calculating the potential coercions. Fixing this so f16 input to a signature expected f64 is now allowed and coerced. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Support coercion of f16 to f64 as specified by signature. Add tests for regr, percentile & covar functions. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 613f87d commit 796c7d1

5 files changed

Lines changed: 42 additions & 3 deletions

File tree

datafusion/expr-common/src/signature.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,6 +1585,7 @@ mod tests {
15851585
vec![DataType::UInt16, DataType::UInt16],
15861586
vec![DataType::UInt32, DataType::UInt32],
15871587
vec![DataType::UInt64, DataType::UInt64],
1588+
vec![DataType::Float16, DataType::Float16],
15881589
vec![DataType::Float32, DataType::Float32],
15891590
vec![DataType::Float64, DataType::Float64]
15901591
]

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub static NUMERICS: &[DataType] = &[
4242
DataType::UInt16,
4343
DataType::UInt32,
4444
DataType::UInt64,
45+
DataType::Float16,
4546
DataType::Float32,
4647
DataType::Float64,
4748
];

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -852,10 +852,13 @@ fn coerced_from<'a>(
852852
(UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
853853
(UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
854854
(UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
855+
(Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => {
856+
Some(type_into.clone())
857+
}
855858
(
856859
Float32,
857860
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
858-
| Float32,
861+
| Float16 | Float32,
859862
) => Some(type_into.clone()),
860863
(
861864
Float64,
@@ -868,6 +871,7 @@ fn coerced_from<'a>(
868871
| UInt16
869872
| UInt32
870873
| UInt64
874+
| Float16
871875
| Float32
872876
| Float64
873877
| Decimal32(_, _)

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1892,7 +1892,7 @@ mod test {
18921892
.err()
18931893
.unwrap()
18941894
.strip_backtrace();
1895-
assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1895+
assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64]) failed"));
18961896
Ok(())
18971897
}
18981898

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,16 @@ SELECT covar(c2, c12) FROM aggregate_test_100
571571
----
572572
-0.079969012479
573573

574+
query R
575+
SELECT covar_pop(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100
576+
----
577+
-0.079163311005
578+
579+
query R
580+
SELECT covar(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100
581+
----
582+
-0.079962940409
583+
574584
# single_row_query_covar_1
575585
query R
576586
select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq
@@ -1313,6 +1323,24 @@ select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median
13131323
----
13141324
2.75 Float16
13151325

1326+
# This shouldn't be NaN, see:
1327+
# https://github.com/apache/datafusion/issues/18945
1328+
query RT
1329+
select
1330+
percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')),
1331+
arrow_typeof(percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')))
1332+
from median_table;
1333+
----
1334+
NaN Float16
1335+
1336+
query RT
1337+
select
1338+
approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')),
1339+
arrow_typeof(approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')))
1340+
from median_table;
1341+
----
1342+
2.75 Float16
1343+
13161344
query ?T
13171345
select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table;
13181346
----
@@ -6719,7 +6747,12 @@ from aggregate_test_100;
67196747
----
67206748
0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695
67216749

6722-
6750+
query R
6751+
select
6752+
regr_slope(arrow_cast(c12, 'Float16'), arrow_cast(c11, 'Float16'))
6753+
from aggregate_test_100;
6754+
----
6755+
0.051477733249
67236756

67246757
# regr_*() functions ignore NULLs
67256758
query RRIRRRRRR

0 commit comments

Comments
 (0)