Skip to content

Commit e572a45

Browse files
authored
Fix Unsound Binary Casting in Unreleased Arrow (#3691) (#3692)
* Fix binary casting (#3691)
* Clippy
* More clippy
* Update test
1 parent 3761ac5 commit e572a45

File tree

3 files changed

+69
-61
lines changed

3 files changed

+69
-61
lines changed

arrow-array/src/array/string_array.rs

Lines changed: 33 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ use crate::{
2121
};
2222
use arrow_buffer::{bit_util, MutableBuffer};
2323
use arrow_data::ArrayData;
24-
use arrow_schema::DataType;
24+
use arrow_schema::{ArrowError, DataType};
2525

2626
/// Generic struct for \[Large\]StringArray
2727
///
@@ -99,6 +99,34 @@ impl<OffsetSize: OffsetSizeTrait> GenericStringArray<OffsetSize> {
9999
) -> impl Iterator<Item = Option<&str>> + 'a {
100100
indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
101101
}
102+
103+
/// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`], returning an
104+
/// [`ArrowError::CastError`] if the values are not valid UTF-8 or an offset splits a codepoint
105+
pub fn try_from_binary(
106+
v: GenericBinaryArray<OffsetSize>,
107+
) -> Result<Self, ArrowError> {
108+
let offsets = v.value_offsets();
109+
let values = v.value_data();
110+
111+
// Validate the whole values buffer as UTF-8 in one call; the loop below checks the offsets
112+
let validated = std::str::from_utf8(values).map_err(|e| {
113+
ArrowError::CastError(format!("Encountered non UTF-8 data: {e}"))
114+
})?;
115+
116+
for offset in offsets.iter() {
117+
let o = offset.as_usize();
118+
if !validated.is_char_boundary(o) {
119+
return Err(ArrowError::CastError(format!(
120+
"Split UTF-8 codepoint at offset {o}"
121+
)));
122+
}
123+
}
124+
125+
let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
126+
// SAFETY:
127+
// The values buffer is valid UTF-8 and every offset is on a char boundary (checked above)
128+
Ok(Self::from(unsafe { builder.build_unchecked() }))
129+
}
102130
}
103131

104132
impl<'a, Ptr, OffsetSize: OffsetSizeTrait> FromIterator<&'a Option<Ptr>>
@@ -172,22 +200,7 @@ impl<OffsetSize: OffsetSizeTrait> From<GenericBinaryArray<OffsetSize>>
172200
for GenericStringArray<OffsetSize>
173201
{
174202
fn from(v: GenericBinaryArray<OffsetSize>) -> Self {
175-
let offsets = v.value_offsets();
176-
let values = v.value_data();
177-
178-
// We only need to validate that all values are valid UTF-8
179-
let validated = std::str::from_utf8(values).expect("Invalid UTF-8 sequence");
180-
for offset in offsets.iter() {
181-
assert!(
182-
validated.is_char_boundary(offset.as_usize()),
183-
"Invalid UTF-8 sequence"
184-
)
185-
}
186-
187-
let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE);
188-
// SAFETY:
189-
// Validated UTF-8 above
190-
Self::from(unsafe { builder.build_unchecked() })
203+
Self::try_from_binary(v).unwrap()
191204
}
192205
}
193206

@@ -650,7 +663,9 @@ mod tests {
650663
}
651664

652665
#[test]
653-
#[should_panic(expected = "Invalid UTF-8 sequence: Utf8Error")]
666+
#[should_panic(
667+
expected = "Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 0"
668+
)]
654669
fn test_list_array_utf8_validation() {
655670
let mut builder = ListBuilder::new(PrimitiveBuilder::<UInt8Type>::new());
656671
builder.values().append_value(0xFF);

arrow-cast/src/cast.rs

Lines changed: 35 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -3202,49 +3202,25 @@ fn cast_binary_to_string<O: OffsetSizeTrait>(
32023202
.downcast_ref::<GenericByteArray<GenericBinaryType<O>>>()
32033203
.unwrap();
32043204

3205-
if !cast_options.safe {
3206-
let offsets = array.value_offsets();
3207-
let values = array.value_data();
3208-
3209-
// We only need to validate that all values are valid UTF-8
3210-
let validated = std::str::from_utf8(values)
3211-
.map_err(|_| ArrowError::CastError("Invalid UTF-8 sequence".to_string()))?;
3212-
// Checks if the offsets are valid but does not re-encode
3213-
for offset in offsets.iter() {
3214-
if !validated.is_char_boundary(offset.as_usize()) {
3215-
return Err(ArrowError::CastError("Invalid UTF-8 sequence".to_string()));
3205+
match GenericStringArray::<O>::try_from_binary(array.clone()) {
3206+
Ok(a) => Ok(Arc::new(a)),
3207+
Err(e) => match cast_options.safe {
3208+
true => {
3209+
// Fallback to slow method to convert invalid sequences to nulls
3210+
let mut builder = GenericStringBuilder::<O>::with_capacity(
3211+
array.len(),
3212+
array.value_data().len(),
3213+
);
3214+
3215+
let iter = array
3216+
.iter()
3217+
.map(|v| v.and_then(|v| std::str::from_utf8(v).ok()));
3218+
3219+
builder.extend(iter);
3220+
Ok(Arc::new(builder.finish()))
32163221
}
3217-
}
3218-
3219-
let builder = array
3220-
.into_data()
3221-
.into_builder()
3222-
.data_type(GenericStringArray::<O>::DATA_TYPE);
3223-
// SAFETY:
3224-
// Validated UTF-8 above
3225-
Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
3226-
builder.build_unchecked()
3227-
})))
3228-
} else {
3229-
let mut null_builder = BooleanBufferBuilder::new(array.len());
3230-
array.iter().for_each(|maybe_value| {
3231-
null_builder.append(
3232-
maybe_value
3233-
.and_then(|value| std::str::from_utf8(value).ok())
3234-
.is_some(),
3235-
);
3236-
});
3237-
3238-
let builder = array
3239-
.into_data()
3240-
.into_builder()
3241-
.null_bit_buffer(Some(null_builder.finish()))
3242-
.data_type(GenericStringArray::<O>::DATA_TYPE);
3243-
// SAFETY:
3244-
// Validated UTF-8 above
3245-
Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
3246-
builder.build_unchecked()
3247-
})))
3222+
false => Err(e),
3223+
},
32483224
}
32493225
}
32503226

@@ -7588,4 +7564,21 @@ mod tests {
75887564
test_tz("+00:00".to_owned());
75897565
test_tz("+02:00".to_owned());
75907566
}
7567+
7568+
// Casting binary data that contains invalid UTF-8 with `safe: true` is expected to
// succeed, turning the offending value into a null and keeping the valid value.
#[test]
7569+
fn test_cast_invalid_utf8() {
7570+
// 0xFF never occurs in well-formed UTF-8, so this value cannot be a string
let v1: &[u8] = b"\xFF invalid";
7571+
// A leading NUL byte is still valid UTF-8
let v2: &[u8] = b"\x00 Foo";
7572+
let s = BinaryArray::from(vec![v1, v2]);
7573+
let options = CastOptions { safe: true };
7574+
let array = cast_with_options(&s, &DataType::Utf8, &options).unwrap();
7575+
let a = as_string_array(array.as_ref());
7576+
// Full validation of the resulting array data must pass
a.data().validate_full().unwrap();
7577+
7578+
assert_eq!(a.null_count(), 1);
7579+
assert_eq!(a.len(), 2);
7580+
assert!(a.is_null(0));
7581+
// The nulled slot reads back as an empty string
assert_eq!(a.value(0), "");
7582+
assert_eq!(a.value(1), "\x00 Foo");
7583+
}
75917584
}

arrow-row/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1734,7 +1734,7 @@ mod tests {
17341734
}
17351735

17361736
#[test]
1737-
#[should_panic(expected = "Invalid UTF-8 sequence")]
1737+
#[should_panic(expected = "Encountered non UTF-8 data")]
17381738
fn test_invalid_utf8() {
17391739
let mut converter =
17401740
RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();

0 commit comments

Comments
 (0)