Skip to content

Commit d0b154b

Browse files
authored
bug: Fix string-to-decimal cast to throw the right exception (#3248)
1 parent abd0bb8 commit d0b154b

4 files changed

Lines changed: 194 additions & 80 deletions

File tree

native/spark-expr/benches/cast_from_string.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,31 @@ fn criterion_benchmark(c: &mut Criterion) {
6868
b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
6969
});
7070
group.finish();
71+
72+
// str -> decimal benchmark
73+
let decimal_string_batch = create_decimal_cast_string_batch();
74+
for (mode, mode_name) in [
75+
(EvalMode::Legacy, "legacy"),
76+
(EvalMode::Ansi, "ansi"),
77+
(EvalMode::Try, "try"),
78+
] {
79+
let spark_cast_options = SparkCastOptions::new(mode, "", false);
80+
let cast_to_decimal_38_10 = Cast::new(
81+
expr.clone(),
82+
DataType::Decimal128(38, 10),
83+
spark_cast_options,
84+
);
85+
86+
let mut group = c.benchmark_group(format!("cast_string_to_decimal/{}", mode_name));
87+
group.bench_function("decimal_38_10", |b| {
88+
b.iter(|| {
89+
cast_to_decimal_38_10
90+
.evaluate(&decimal_string_batch)
91+
.unwrap()
92+
});
93+
});
94+
group.finish();
95+
}
7196
}
7297

7398
/// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks)
@@ -118,6 +143,51 @@ fn create_decimal_string_batch() -> RecordBatch {
118143
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
119144
}
120145

146+
/// Create batch with decimal strings for string-to-decimal cast perf evaluation
147+
fn create_decimal_cast_string_batch() -> RecordBatch {
148+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
149+
let mut b = StringBuilder::new();
150+
for i in 0..1000 {
151+
if i % 10 == 0 {
152+
b.append_null();
153+
} else {
154+
// Generate various decimal formats
155+
match i % 5 {
156+
0 => {
157+
// Generate simple decimals (e.g. "123.45")
158+
let int_part: u32 = rand::random::<u32>() % 1000000;
159+
let dec_part: u32 = rand::random::<u32>() % 100000;
160+
b.append_value(format!("{}.{}", int_part, dec_part));
161+
}
162+
1 => {
163+
// Generate scientific-notation strings (e.g. "1.23E5")
164+
let mantissa: u32 = rand::random::<u32>() % 1000;
165+
let exp: i8 = (rand::random::<i8>() % 10).abs();
166+
b.append_value(format!("{}.{}E{}", mantissa / 100, mantissa % 100, exp));
167+
}
168+
2 => {
169+
// Negative numbers
170+
let int_part: u32 = rand::random::<u32>() % 1000000;
171+
let dec_part: u32 = rand::random::<u32>() % 100000;
172+
b.append_value(format!("-{}.{}", int_part, dec_part));
173+
}
174+
3 => {
175+
// Ints only
176+
let val: i32 = rand::random::<i32>() % 1000000;
177+
b.append_value(format!("{}", val));
178+
}
179+
_ => {
180+
// Small decimals (e.g. "0.001")
181+
let dec_part: u32 = rand::random::<u32>() % 100000;
182+
b.append_value(format!("0.{:05}", dec_part));
183+
}
184+
}
185+
}
186+
}
187+
let array = b.finish();
188+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
189+
}
190+
121191
fn config() -> Criterion {
122192
Criterion::default()
123193
}

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 97 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2321,8 +2321,8 @@ fn cast_string_to_decimal256_impl(
23212321
}
23222322

23232323
/// Parse a string to decimal following Spark's behavior
2324-
fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
2325-
let string_bytes = s.as_bytes();
2324+
fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
2325+
let string_bytes = input_str.as_bytes();
23262326
let mut start = 0;
23272327
let mut end = string_bytes.len();
23282328

@@ -2334,7 +2334,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
23342334
end -= 1;
23352335
}
23362336

2337-
let trimmed = &s[start..end];
2337+
let trimmed = &input_str[start..end];
23382338

23392339
if trimmed.is_empty() {
23402340
return Ok(None);
@@ -2351,73 +2351,101 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
23512351
return Ok(None);
23522352
}
23532353

2354-
// validate and parse mantissa and exponent
2355-
match parse_decimal_str(trimmed) {
2356-
Ok((mantissa, exponent)) => {
2357-
// Convert to target scale
2358-
let target_scale = scale as i32;
2359-
let scale_adjustment = target_scale - exponent;
2354+
// validate and parse mantissa and exponent or bubble up the error
2355+
let (mantissa, exponent) = parse_decimal_str(trimmed, input_str, precision, scale)?;
23602356

2361-
let scaled_value = if scale_adjustment >= 0 {
2362-
// Need to multiply (increase scale) but return None if scale is too high to fit i128
2363-
if scale_adjustment > 38 {
2364-
return Ok(None);
2365-
}
2366-
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
2367-
} else {
2368-
// Need to multiply (increase scale) but return None if scale is too high to fit i128
2369-
let abs_scale_adjustment = (-scale_adjustment) as u32;
2370-
if abs_scale_adjustment > 38 {
2371-
return Ok(Some(0));
2372-
}
2357+
// Early return when the mantissa is 0: Spark still checks whether the value fits the digit range and throws in ANSI mode
2358+
if mantissa == 0 {
2359+
if exponent < -37 {
2360+
return Err(SparkError::NumericOutOfRange {
2361+
value: input_str.to_string(),
2362+
});
2363+
}
2364+
return Ok(Some(0));
2365+
}
23732366

2374-
let divisor = 10_i128.pow(abs_scale_adjustment);
2375-
let quotient_opt = mantissa.checked_div(divisor);
2376-
// Check if divisor is 0
2377-
if quotient_opt.is_none() {
2378-
return Ok(None);
2379-
}
2380-
let quotient = quotient_opt.unwrap();
2381-
let remainder = mantissa % divisor;
2382-
2383-
// Round half up: if abs(remainder) >= divisor/2, round away from zero
2384-
let half_divisor = divisor / 2;
2385-
let rounded = if remainder.abs() >= half_divisor {
2386-
if mantissa >= 0 {
2387-
quotient + 1
2388-
} else {
2389-
quotient - 1
2390-
}
2391-
} else {
2392-
quotient
2393-
};
2394-
Some(rounded)
2395-
};
2367+
// scale adjustment
2368+
let target_scale = scale as i32;
2369+
let scale_adjustment = target_scale - exponent;
23962370

2397-
match scaled_value {
2398-
Some(value) => {
2399-
// Check if it fits target precision
2400-
if is_validate_decimal_precision(value, precision) {
2401-
Ok(Some(value))
2402-
} else {
2403-
Ok(None)
2404-
}
2405-
}
2406-
None => {
2407-
// Overflow while scaling
2408-
Ok(None)
2409-
}
2371+
let scaled_value = if scale_adjustment >= 0 {
2372+
// Need to multiply (increase scale) but return None if scale is too high to fit i128
2373+
if scale_adjustment > 38 {
2374+
return Ok(None);
2375+
}
2376+
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
2377+
} else {
2378+
// Need to divide (decrease scale)
2379+
let abs_scale_adjustment = (-scale_adjustment) as u32;
2380+
if abs_scale_adjustment > 38 {
2381+
return Ok(Some(0));
2382+
}
2383+
2384+
let divisor = 10_i128.pow(abs_scale_adjustment);
2385+
let quotient_opt = mantissa.checked_div(divisor);
2386+
// Check if divisor is 0
2387+
if quotient_opt.is_none() {
2388+
return Ok(None);
2389+
}
2390+
let quotient = quotient_opt.unwrap();
2391+
let remainder = mantissa % divisor;
2392+
2393+
// Round half up: if abs(remainder) >= divisor/2, round away from zero
2394+
let half_divisor = divisor / 2;
2395+
let rounded = if remainder.abs() >= half_divisor {
2396+
if mantissa >= 0 {
2397+
quotient + 1
2398+
} else {
2399+
quotient - 1
2400+
}
2401+
} else {
2402+
quotient
2403+
};
2404+
Some(rounded)
2405+
};
2406+
2407+
match scaled_value {
2408+
Some(value) => {
2409+
if is_validate_decimal_precision(value, precision) {
2410+
Ok(Some(value))
2411+
} else {
2412+
// Value parsed successfully but exceeds the target precision; throw an error
2413+
Err(SparkError::NumericValueOutOfRange {
2414+
value: trimmed.to_string(),
2415+
precision,
2416+
scale,
2417+
})
24102418
}
24112419
}
2412-
Err(_) => Ok(None),
2420+
None => {
2421+
// Overflow while scaling; raise an exception
2422+
Err(SparkError::NumericValueOutOfRange {
2423+
value: trimmed.to_string(),
2424+
precision,
2425+
scale,
2426+
})
2427+
}
24132428
}
24142429
}
24152430

2431+
fn invalid_decimal_cast(value: &str, precision: u8, scale: i8) -> SparkError {
2432+
invalid_value(
2433+
value,
2434+
"STRING",
2435+
&format!("DECIMAL({},{})", precision, scale),
2436+
)
2437+
}
2438+
24162439
/// Parse a decimal string into mantissa and scale
2417-
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
2418-
fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
2440+
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3), "0e50" -> (0, -50), etc.
2441+
fn parse_decimal_str(
2442+
s: &str,
2443+
original_str: &str,
2444+
precision: u8,
2445+
scale: i8,
2446+
) -> SparkResult<(i128, i32)> {
24192447
if s.is_empty() {
2420-
return Err("Empty string".to_string());
2448+
return Err(invalid_decimal_cast(original_str, precision, scale));
24212449
}
24222450

24232451
let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
@@ -2426,7 +2454,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
24262454
// Parse exponent
24272455
let exp: i32 = exponent_part
24282456
.parse()
2429-
.map_err(|e| format!("Invalid exponent: {}", e))?;
2457+
.map_err(|_| invalid_decimal_cast(original_str, precision, scale))?;
24302458

24312459
(mantissa_part, exp)
24322460
} else {
@@ -2441,29 +2469,29 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
24412469
};
24422470

24432471
if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
2444-
return Err("Invalid sign format".to_string());
2472+
return Err(invalid_decimal_cast(original_str, precision, scale));
24452473
}
24462474

24472475
let (integral_part, fractional_part) = match mantissa_str.find('.') {
24482476
Some(dot_pos) => {
24492477
if mantissa_str[dot_pos + 1..].contains('.') {
2450-
return Err("Multiple decimal points".to_string());
2478+
return Err(invalid_decimal_cast(original_str, precision, scale));
24512479
}
24522480
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
24532481
}
24542482
None => (mantissa_str, ""),
24552483
};
24562484

24572485
if integral_part.is_empty() && fractional_part.is_empty() {
2458-
return Err("No digits found".to_string());
2486+
return Err(invalid_decimal_cast(original_str, precision, scale));
24592487
}
24602488

24612489
if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
2462-
return Err("Invalid integral part".to_string());
2490+
return Err(invalid_decimal_cast(original_str, precision, scale));
24632491
}
24642492

24652493
if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
2466-
return Err("Invalid fractional part".to_string());
2494+
return Err(invalid_decimal_cast(original_str, precision, scale));
24672495
}
24682496

24692497
// Parse integral part
@@ -2473,7 +2501,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
24732501
} else {
24742502
integral_part
24752503
.parse()
2476-
.map_err(|_| "Invalid integral part".to_string())?
2504+
.map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
24772505
};
24782506

24792507
// Parse fractional part
@@ -2483,14 +2511,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
24832511
} else {
24842512
fractional_part
24852513
.parse()
2486-
.map_err(|_| "Invalid fractional part".to_string())?
2514+
.map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
24872515
};
24882516

24892517
// Combine: value = integral * 10^fractional_scale + fractional
24902518
let mantissa = integral_value
24912519
.checked_mul(10_i128.pow(fractional_scale as u32))
24922520
.and_then(|v| v.checked_add(fractional_value))
2493-
.ok_or("Overflow in mantissa calculation")?;
2521+
.ok_or_else(|| invalid_decimal_cast(original_str, precision, scale))?;
24942522

24952523
let final_mantissa = if negative { -mantissa } else { mantissa };
24962524
// final scale = fractional_scale - exponent

native/spark-expr/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ pub enum SparkError {
3939
scale: i8,
4040
},
4141

42+
#[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")]
43+
NumericOutOfRange { value: String },
44+
4245
#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
4346
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
4447
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]

0 commit comments

Comments
 (0)