Skip to content

Commit 4175544

Browse files
fangchenli authored and zhengruifeng committed
[SPARK-54941][PYTHON][TESTS] Add tests for pa.scalar type coercion for numerical types
### What changes were proposed in this pull request?
Add tests for pa.scalar type coercion.

### Why are the changes needed?
We want to monitor changes in PyArrow's behavior.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Opus 4.5

Closes #53762 from fangchenli/pa-scalar-coercion-tests.

Authored-by: Fangchen Li <fangchen.li@outlook.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent a281331 commit 4175544

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -492,6 +492,7 @@ def __hash__(self):
492492
# unittests for upstream projects
493493
"pyspark.tests.upstream.pyarrow.test_pyarrow_array_type_inference",
494494
"pyspark.tests.upstream.pyarrow.test_pyarrow_ignore_timezone",
495+
"pyspark.tests.upstream.pyarrow.test_pyarrow_scalar_type_coercion",
495496
"pyspark.tests.upstream.pyarrow.test_pyarrow_scalar_type_inference",
496497
"pyspark.tests.upstream.pyarrow.test_pyarrow_type_coercion",
497498
],
Lines changed: 240 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,240 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
Test pa.scalar type coercion behavior when creating scalars with explicit type parameter.
20+
21+
This test monitors the behavior of PyArrow's type coercion to ensure PySpark's assumptions
22+
about PyArrow behavior remain valid across versions.
23+
24+
Test categories:
25+
1. Null coercion - None to numeric types
26+
2. Numeric types - int, float, decimal coercion and boundaries
27+
28+
The helper method pattern is adapted from PR #53721 (pa.array type coercion tests).
29+
"""
30+
31+
import math
32+
import unittest
33+
from decimal import Decimal
34+
35+
from pyspark.testing.utils import (
36+
have_pyarrow,
37+
pyarrow_requirement_message,
38+
)
39+
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class PyArrowScalarTypeCoercionTests(unittest.TestCase):
    """Monitor how ``pa.scalar(value, type=...)`` coerces Python values.

    These tests pin down PyArrow's coercion rules for numeric types so that
    a behavior change in a newer PyArrow release surfaces as a test failure,
    keeping PySpark's assumptions about PyArrow explicit and checked.

    Covered:
      1. Null coercion — ``None`` into every numeric type.
      2. Numeric coercion — int/float/decimal conversions, boundary values,
         and the coercions PyArrow rejects (with the exception type raised).
    """

    # -------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------

    def _run_coercion_tests(self, cases):
        """Each ``(value, arrow_type)`` pair must coerce to the requested type."""
        import pyarrow as pa

        for py_value, arrow_type in cases:
            result = pa.scalar(py_value, type=arrow_type)
            self.assertEqual(result.type, arrow_type)

    def _run_coercion_tests_with_values(self, cases):
        """Each ``(value, arrow_type, expected)`` triple must coerce to the
        requested type AND round-trip back to the expected Python value."""
        import pyarrow as pa

        for py_value, arrow_type, expected_py in cases:
            result = pa.scalar(py_value, type=arrow_type)
            self.assertEqual(result.type, arrow_type)
            self.assertEqual(result.as_py(), expected_py)

    def _run_error_tests(self, cases, error_type):
        """Each ``(value, arrow_type)`` pair must raise ``error_type``."""
        import pyarrow as pa

        for py_value, arrow_type in cases:
            with self.assertRaises(error_type):
                pa.scalar(py_value, type=arrow_type)

    # -------------------------------------------------------------------
    # Section 1: null coercion
    # -------------------------------------------------------------------

    def test_null_coercion(self):
        """``None`` coerces to a typed null scalar for every numeric type."""
        import pyarrow as pa

        numeric_types = (
            pa.int8(),
            pa.int16(),
            pa.int32(),
            pa.int64(),
            pa.uint8(),
            pa.uint16(),
            pa.uint32(),
            pa.uint64(),
            pa.float32(),
            pa.float64(),
            pa.decimal128(20, 2),
        )

        for arrow_type in numeric_types:
            null_scalar = pa.scalar(None, type=arrow_type)
            self.assertEqual(null_scalar.type, arrow_type)
            self.assertFalse(null_scalar.is_valid)
            self.assertIsNone(null_scalar.as_py())

    # -------------------------------------------------------------------
    # Section 2: numeric coercion
    # -------------------------------------------------------------------

    def test_numeric_coercion(self):
        """Successful int/float/decimal coercions, including boundary values."""
        import pyarrow as pa

        # Integer -> integer, with type-boundary values included.
        self._run_coercion_tests_with_values(
            [
                (42, pa.int8(), 42),
                (42, pa.int16(), 42),
                (42, pa.int32(), 42),
                (42, pa.int64(), 42),
                (127, pa.int8(), 127),  # max int8
                (-128, pa.int8(), -128),  # min int8
                (0, pa.uint8(), 0),
                (255, pa.uint8(), 255),  # max uint8
                (2**62, pa.int64(), 2**62),
            ]
        )

        # Integer -> float.
        self._run_coercion_tests_with_values(
            [
                (42, pa.float32(), 42.0),
                (42, pa.float64(), 42.0),
                (0, pa.float32(), 0.0),
                (-100, pa.float64(), -100.0),
            ]
        )

        # Integer -> decimal: the scale is zero-padded.
        self._run_coercion_tests_with_values(
            [
                (42, pa.decimal128(10, 2), Decimal("42.00")),
                (-999, pa.decimal128(10, 2), Decimal("-999.00")),
                (0, pa.decimal128(10, 2), Decimal("0.00")),
            ]
        )

        # Float -> float, finite and infinite; only the type is checked here.
        self._run_coercion_tests(
            [
                (0.0, pa.float32()),
                (0.0, pa.float64()),
                (3.14, pa.float32()),
                (3.14, pa.float64()),
                (-3.14, pa.float32()),
                (-3.14, pa.float64()),
                (float("inf"), pa.float32()),
                (float("inf"), pa.float64()),
                (float("-inf"), pa.float32()),
                (float("-inf"), pa.float64()),
            ]
        )

        # NaN needs math.isnan since NaN != NaN under assertEqual.
        for arrow_type in (pa.float32(), pa.float64()):
            nan_scalar = pa.scalar(float("nan"), type=arrow_type)
            self.assertEqual(nan_scalar.type, arrow_type)
            self.assertTrue(math.isnan(nan_scalar.as_py()))

        # Float -> integer: fractional part is dropped (truncation toward
        # zero, per the -42.9 -> -42 and -0.5 -> 0 cases below).
        self._run_coercion_tests_with_values(
            [
                (42.9, pa.int64(), 42),
                (-42.9, pa.int64(), -42),
                (42.0, pa.int64(), 42),
                (0.5, pa.int64(), 0),
                (-0.5, pa.int64(), 0),
            ]
        )

        # Decimal -> integer: fractional digits are truncated.
        truncated = pa.scalar(Decimal("123.45"), type=pa.int64())
        self.assertEqual(truncated.type, pa.int64())
        self.assertEqual(truncated.as_py(), 123)

        # Decimal -> wider-scale decimal: zero-padded to the new scale.
        widened = pa.scalar(Decimal("123.45"), type=pa.decimal128(20, 5))
        self.assertEqual(widened.type, pa.decimal128(20, 5))
        self.assertEqual(widened.as_py(), Decimal("123.45000"))

    def test_numeric_coercion_errors(self):
        """Coercions PyArrow rejects, and the exception type each raises."""
        import pyarrow as pa

        # Out-of-range integers raise ArrowInvalid.
        self._run_error_tests(
            [
                (128, pa.int8()),
                (-129, pa.int8()),
                (256, pa.uint8()),
                (32768, pa.int16()),
                (2**62, pa.int32()),
            ],
            pa.ArrowInvalid,
        )

        # Negative -> unsigned raises OverflowError (not an Arrow error).
        for py_value, arrow_type in [
            (-1, pa.uint8()),
            (-1, pa.uint16()),
            (-1, pa.uint32()),
            (-1, pa.uint64()),
        ]:
            with self.assertRaises(OverflowError):
                pa.scalar(py_value, type=arrow_type)

        # 16777217 == 2**24 + 1 is not exactly representable in float32,
        # so the lossy integer -> float32 coercion is rejected.
        with self.assertRaises(pa.ArrowInvalid):
            pa.scalar(16777217, type=pa.float32())

        # Non-finite floats cannot become integers.
        self._run_error_tests(
            [
                (float("nan"), pa.int64()),
                (float("inf"), pa.int64()),
                (float("-inf"), pa.int64()),
            ],
            pa.ArrowInvalid,
        )

        # Float -> decimal is rejected outright with ArrowTypeError.
        self._run_error_tests(
            [
                (42.5, pa.decimal128(10, 2)),
                (0.0, pa.decimal128(10, 2)),
                (3.14, pa.decimal128(10, 2)),
                (float("nan"), pa.decimal128(10, 2)),
            ],
            pa.ArrowTypeError,
        )

        # A decimal that does not fit the target scale would lose precision.
        with self.assertRaises(pa.ArrowInvalid):
            pa.scalar(Decimal("123.456"), type=pa.decimal128(10, 2))

        # Decimal -> float is rejected.
        with self.assertRaises(pa.ArrowInvalid):
            pa.scalar(Decimal("123.45"), type=pa.float64())
237+
if __name__ == "__main__":
238+
from pyspark.testing import main
239+
240+
main()

0 commit comments

Comments
 (0)