Skip to content

Commit 8ba3d26

Browse files
kumarUjjawal and alamb authored
fix: regression of dict_id in physical plan proto (#20063)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20011. ## Rationale for this change - `dict_id` is intentionally not preserved in protobuf (it’s deprecated in Arrow schema metadata), but Arrow IPC still requires dict IDs for dictionary encoding/decoding. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Fix protobuf serde for nested ScalarValue (list/struct/map) containing dictionary arrays by using Arrow IPC’s dictionary handling correctly. - Seed DictionaryTracker by encoding the schema before encoding the nested scalar batch. - On decode, reconstruct an IPC schema from the protobuf schema and use arrow_ipc::reader::read_dictionary to build dict_by_id before reading the record batch. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes, a test was added for this. <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? No <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. 
--> --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent d3ac7a3 commit 8ba3d26

3 files changed

Lines changed: 85 additions & 26 deletions

File tree

datafusion/proto-common/src/from_proto/mod.rs

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ use arrow::datatypes::{
2828
DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema,
2929
TimeUnit, UnionFields, UnionMode, i256,
3030
};
31-
use arrow::ipc::{reader::read_record_batch, root_as_message};
31+
use arrow::ipc::{
32+
convert::fb_to_schema,
33+
reader::{read_dictionary, read_record_batch},
34+
root_as_message,
35+
writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
36+
};
3237

3338
use datafusion_common::{
3439
Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef,
@@ -397,7 +402,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
397402
Value::Float32Value(v) => Self::Float32(Some(*v)),
398403
Value::Float64Value(v) => Self::Float64(Some(*v)),
399404
Value::Date32Value(v) => Self::Date32(Some(*v)),
400-
// ScalarValue::List is serialized using arrow IPC format
405+
// Nested ScalarValue types are serialized using arrow IPC format
401406
Value::ListValue(v)
402407
| Value::FixedSizeListValue(v)
403408
| Value::LargeListValue(v)
@@ -414,55 +419,83 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
414419
schema_ref.try_into()?
415420
} else {
416421
return Err(Error::General(
417-
"Invalid schema while deserializing ScalarValue::List"
422+
"Invalid schema while deserializing nested ScalarValue"
418423
.to_string(),
419424
));
420425
};
421426

427+
// IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf
428+
// `Schema` doesn't preserve those IDs. Reconstruct them deterministically by
429+
// round-tripping the schema through IPC.
430+
let schema: Schema = {
431+
let ipc_gen = IpcDataGenerator {};
432+
let write_options = IpcWriteOptions::default();
433+
let mut dict_tracker = DictionaryTracker::new(false);
434+
let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker(
435+
&schema,
436+
&mut dict_tracker,
437+
&write_options,
438+
);
439+
let message =
440+
root_as_message(encoded_schema.ipc_message.as_slice()).map_err(
441+
|e| {
442+
Error::General(format!(
443+
"Error IPC schema message while deserializing nested ScalarValue: {e}"
444+
))
445+
},
446+
)?;
447+
let ipc_schema = message.header_as_schema().ok_or_else(|| {
448+
Error::General(
449+
"Unexpected message type deserializing nested ScalarValue schema"
450+
.to_string(),
451+
)
452+
})?;
453+
fb_to_schema(ipc_schema)
454+
};
455+
422456
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
423457
Error::General(format!(
424-
"Error IPC message while deserializing ScalarValue::List: {e}"
458+
"Error IPC message while deserializing nested ScalarValue: {e}"
425459
))
426460
})?;
427461
let buffer = Buffer::from(arrow_data.as_slice());
428462

429463
let ipc_batch = message.header_as_record_batch().ok_or_else(|| {
430464
Error::General(
431-
"Unexpected message type deserializing ScalarValue::List"
465+
"Unexpected message type deserializing nested ScalarValue"
432466
.to_string(),
433467
)
434468
})?;
435469

436-
let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
470+
let mut dict_by_id: HashMap<i64, ArrayRef> = HashMap::new();
471+
for protobuf::scalar_nested_value::Dictionary {
472+
ipc_message,
473+
arrow_data,
474+
} in dictionaries
475+
{
437476
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
438477
Error::General(format!(
439-
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
478+
"Error IPC message while deserializing nested ScalarValue dictionary message: {e}"
440479
))
441480
})?;
442481
let buffer = Buffer::from(arrow_data.as_slice());
443482

444483
let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
445484
Error::General(
446-
"Unexpected message type deserializing ScalarValue::List dictionary message"
485+
"Unexpected message type deserializing nested ScalarValue dictionary message"
447486
.to_string(),
448487
)
449488
})?;
450-
451-
let id = dict_batch.id();
452-
453-
let record_batch = read_record_batch(
489+
read_dictionary(
454490
&buffer,
455-
dict_batch.data().unwrap(),
456-
Arc::new(schema.clone()),
457-
&Default::default(),
458-
None,
491+
dict_batch,
492+
&schema,
493+
&mut dict_by_id,
459494
&message.version(),
460-
)?;
461-
462-
let values: ArrayRef = Arc::clone(record_batch.column(0));
463-
464-
Ok((id, values))
465-
}).collect::<datafusion_common::Result<HashMap<_, _>>>()?;
495+
)
496+
.map_err(|e| arrow_datafusion_err!(e))
497+
.map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?;
498+
}
466499

467500
let record_batch = read_record_batch(
468501
&buffer,
@@ -473,7 +506,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
473506
&message.version(),
474507
)
475508
.map_err(|e| arrow_datafusion_err!(e))
476-
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
509+
.map_err(|e| e.context("Decoding nested ScalarValue value"))?;
477510
let arr = record_batch.column(0);
478511
match value {
479512
Value::ListValue(_) => {

datafusion/proto-common/src/to_proto/mod.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,21 +1031,28 @@ fn create_proto_scalar<I, T: FnOnce(&I) -> protobuf::scalar_value::Value>(
10311031
Ok(protobuf::ScalarValue { value: Some(value) })
10321032
}
10331033

1034-
// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using
1034+
// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using
10351035
// Arrow IPC messages as a single column RecordBatch
10361036
fn encode_scalar_nested_value(
10371037
arr: ArrayRef,
10381038
val: &ScalarValue,
10391039
) -> Result<protobuf::ScalarValue, Error> {
10401040
let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| {
10411041
Error::General(format!(
1042-
"Error creating temporary batch while encoding ScalarValue::List: {e}"
1042+
"Error creating temporary batch while encoding nested ScalarValue: {e}"
10431043
))
10441044
})?;
10451045

10461046
let ipc_gen = IpcDataGenerator {};
10471047
let mut dict_tracker = DictionaryTracker::new(false);
10481048
let write_options = IpcWriteOptions::default();
1049+
// The IPC writer requires pre-allocated dictionary IDs (normally assigned when
1050+
// serializing the schema). Populate `dict_tracker` by encoding the schema first.
1051+
ipc_gen.schema_to_bytes_with_dictionary_tracker(
1052+
batch.schema().as_ref(),
1053+
&mut dict_tracker,
1054+
&write_options,
1055+
);
10491056
let mut compression_context = CompressionContext::default();
10501057
let (encoded_dictionaries, encoded_message) = ipc_gen
10511058
.encode(
@@ -1055,7 +1062,7 @@ fn encode_scalar_nested_value(
10551062
&mut compression_context,
10561063
)
10571064
.map_err(|e| {
1058-
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
1065+
Error::General(format!("Error encoding nested ScalarValue as IPC: {e}"))
10591066
})?;
10601067

10611068
let schema: protobuf::Schema = batch.schema().try_into()?;

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2566,6 +2566,25 @@ fn custom_proto_converter_intercepts() -> Result<()> {
25662566
Ok(())
25672567
}
25682568

2569+
#[test]
2570+
fn roundtrip_call_null_scalar_struct_dict() -> Result<()> {
2571+
let data_type = DataType::Struct(Fields::from(vec![Field::new(
2572+
"item",
2573+
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
2574+
true,
2575+
)]));
2576+
2577+
let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)]));
2578+
let scan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
2579+
let scalar = lit(ScalarValue::try_from(data_type)?);
2580+
let filter = Arc::new(FilterExec::try_new(
2581+
Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)),
2582+
scan,
2583+
)?);
2584+
2585+
roundtrip_test(filter)
2586+
}
2587+
25692588
/// Test that expression deduplication works during deserialization.
25702589
/// When the same expression Arc is serialized multiple times, it should be
25712590
/// deduplicated on deserialization (sharing the same Arc).

0 commit comments

Comments
 (0)