Skip to content

Commit 6477c2a

Browse files
alamb and gruuya authored
[branch-52] fix: use spill writer's schema instead of the first batch schema for spill files (#21293) (#21403)
- Part of #21078
- Closes #21293 on branch-52

This PR backports #21293 from @gruuya to the branch-52 line.

Co-authored-by: Marko Grujic <markoog@gmail.com>
1 parent 6e90ea8 commit 6477c2a

3 files changed

Lines changed: 243 additions & 2 deletions

File tree

datafusion/core/tests/memory_limit/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::sync::{Arc, LazyLock};
2424
#[cfg(feature = "extended_tests")]
2525
mod memory_limit_validation;
2626
mod repartition_mem_limit;
27+
mod union_nullable_spill;
2728
use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray};
2829
use arrow::compute::SortOptions;
2930
use arrow::datatypes::{Int32Type, SchemaRef};
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Array, Int64Array, RecordBatch};
21+
use arrow::compute::SortOptions;
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use datafusion::datasource::memory::MemorySourceConfig;
24+
use datafusion_execution::config::SessionConfig;
25+
use datafusion_execution::memory_pool::FairSpillPool;
26+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
27+
use datafusion_physical_expr::expressions::col;
28+
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
29+
use datafusion_physical_plan::repartition::RepartitionExec;
30+
use datafusion_physical_plan::sorts::sort::sort_batch;
31+
use datafusion_physical_plan::union::UnionExec;
32+
use datafusion_physical_plan::{ExecutionPlan, Partitioning};
33+
use futures::StreamExt;
34+
35+
// Dimensions of the generated test data: each UNION branch produces
// NUM_BATCHES batches of ROWS_PER_BATCH rows (2000 rows per branch), enough
// volume for the tiny spill pool in `build_task_ctx` to force spilling.
const NUM_BATCHES: usize = 200;
const ROWS_PER_BATCH: usize = 10;
38+
fn non_nullable_schema() -> Arc<Schema> {
39+
Arc::new(Schema::new(vec![
40+
Field::new("key", DataType::Int64, false),
41+
Field::new("val", DataType::Int64, false),
42+
]))
43+
}
44+
45+
fn nullable_schema() -> Arc<Schema> {
46+
Arc::new(Schema::new(vec![
47+
Field::new("key", DataType::Int64, false),
48+
Field::new("val", DataType::Int64, true),
49+
]))
50+
}
51+
52+
fn non_nullable_batches() -> Vec<RecordBatch> {
53+
(0..NUM_BATCHES)
54+
.map(|i| {
55+
let start = (i * ROWS_PER_BATCH) as i64;
56+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
57+
RecordBatch::try_new(
58+
non_nullable_schema(),
59+
vec![
60+
Arc::new(Int64Array::from(keys)),
61+
Arc::new(Int64Array::from(vec![0i64; ROWS_PER_BATCH])),
62+
],
63+
)
64+
.unwrap()
65+
})
66+
.collect()
67+
}
68+
69+
fn nullable_batches() -> Vec<RecordBatch> {
70+
(0..NUM_BATCHES)
71+
.map(|i| {
72+
let start = (i * ROWS_PER_BATCH) as i64;
73+
let keys: Vec<i64> = (start..start + ROWS_PER_BATCH as i64).collect();
74+
let vals: Vec<Option<i64>> = (0..ROWS_PER_BATCH)
75+
.map(|j| if j % 3 == 1 { None } else { Some(j as i64) })
76+
.collect();
77+
RecordBatch::try_new(
78+
nullable_schema(),
79+
vec![
80+
Arc::new(Int64Array::from(keys)),
81+
Arc::new(Int64Array::from(vals)),
82+
],
83+
)
84+
.unwrap()
85+
})
86+
.collect()
87+
}
88+
89+
fn build_task_ctx(pool_size: usize) -> Arc<datafusion_execution::TaskContext> {
90+
let session_config = SessionConfig::new().with_batch_size(2);
91+
let runtime = RuntimeEnvBuilder::new()
92+
.with_memory_pool(Arc::new(FairSpillPool::new(pool_size)))
93+
.build_arc()
94+
.unwrap();
95+
Arc::new(
96+
datafusion_execution::TaskContext::default()
97+
.with_session_config(session_config)
98+
.with_runtime(runtime),
99+
)
100+
}
101+
102+
/// Exercises spilling through UnionExec -> RepartitionExec where union children
103+
/// have mismatched nullability (one child's `val` is non-nullable, the other's
104+
/// is nullable with NULLs). A tiny FairSpillPool forces all batches to spill.
105+
///
106+
/// UnionExec returns child streams without schema coercion, so batches from
107+
/// different children carry different per-field nullability into the shared
108+
/// SpillPool. The IPC writer must use the SpillManager's canonical (nullable)
109+
/// schema — not the first batch's schema — so readback batches are valid.
110+
///
111+
/// Otherwise, sort_batch will panic with
112+
/// `Column 'val' is declared as non-nullable but contains null values`
113+
#[tokio::test]
114+
async fn test_sort_union_repartition_spill_mixed_nullability() {
115+
let non_nullable_exec = MemorySourceConfig::try_new_exec(
116+
&[non_nullable_batches()],
117+
non_nullable_schema(),
118+
None,
119+
)
120+
.unwrap();
121+
122+
let nullable_exec =
123+
MemorySourceConfig::try_new_exec(&[nullable_batches()], nullable_schema(), None)
124+
.unwrap();
125+
126+
let union_exec = UnionExec::try_new(vec![non_nullable_exec, nullable_exec]).unwrap();
127+
assert!(union_exec.schema().field(1).is_nullable());
128+
129+
let repartition = Arc::new(
130+
RepartitionExec::try_new(union_exec, Partitioning::RoundRobinBatch(1)).unwrap(),
131+
);
132+
133+
let task_ctx = build_task_ctx(200);
134+
let mut stream = repartition.execute(0, task_ctx).unwrap();
135+
136+
let sort_expr = LexOrdering::new(vec![PhysicalSortExpr {
137+
expr: col("key", &nullable_schema()).unwrap(),
138+
options: SortOptions::default(),
139+
}])
140+
.unwrap();
141+
142+
let mut total_rows = 0usize;
143+
let mut total_nulls = 0usize;
144+
while let Some(result) = stream.next().await {
145+
let batch = result.unwrap();
146+
147+
let batch = sort_batch(&batch, &sort_expr, None).unwrap();
148+
149+
total_rows += batch.num_rows();
150+
total_nulls += batch.column(1).null_count();
151+
}
152+
153+
assert_eq!(
154+
total_rows,
155+
NUM_BATCHES * ROWS_PER_BATCH * 2,
156+
"All rows from both UNION branches should be present"
157+
);
158+
assert!(
159+
total_nulls > 0,
160+
"Expected some null values in output (i.e. nullable batches were processed)"
161+
);
162+
}

datafusion/physical-plan/src/spill/in_progress_spill_file.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,12 @@ impl InProgressSpillFile {
6262
));
6363
}
6464
if self.writer.is_none() {
65-
let schema = batch.schema();
66-
if let Some(ref in_progress_file) = self.in_progress_file {
65+
// Use the SpillManager's declared schema rather than the batch's schema.
66+
// Individual batches may have different schemas (e.g., different nullability)
67+
// when they come from different branches of a UnionExec. The SpillManager's
68+
// schema represents the canonical schema that all batches should conform to.
69+
let schema = self.spill_writer.schema();
70+
if let Some(in_progress_file) = &mut self.in_progress_file {
6771
self.writer = Some(IPCStreamWriter::new(
6872
in_progress_file.path(),
6973
schema.as_ref(),
@@ -121,3 +125,77 @@ impl InProgressSpillFile {
121125
Ok(self.in_progress_file.take())
122126
}
123127
}
128+
129+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_physical_expr_common::metrics::{
        ExecutionPlanMetricsSet, SpillMetrics,
    };
    use futures::TryStreamExt;

    /// Appending two batches whose `val` nullability disagrees must round-trip
    /// through the spill file using the SpillManager's canonical (nullable)
    /// schema rather than the first batch's stricter schema.
    #[tokio::test]
    async fn test_spill_file_uses_spill_manager_schema() -> Result<()> {
        // Canonical schema declared on the SpillManager: `val` is nullable.
        let manager_schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, true),
        ]));
        // Schema carried by the first appended batch: `val` is NOT nullable.
        let strict_schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int64, false),
            Field::new("val", DataType::Int64, false),
        ]));

        let env = Arc::new(RuntimeEnvBuilder::new().build()?);
        let plan_metrics = ExecutionPlanMetricsSet::new();
        let manager = Arc::new(SpillManager::new(
            env,
            SpillMetrics::new(&plan_metrics, 0),
            Arc::clone(&manager_schema),
        ));

        let mut spill_file = manager.create_in_progress_file("test")?;

        // First batch: non-nullable val (simulates literal-0 UNION branch)
        let strict_batch = RecordBatch::try_new(
            Arc::clone(&strict_schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(Int64Array::from(vec![0, 0, 0])),
            ],
        )?;
        spill_file.append_batch(&strict_batch)?;

        // Second batch: nullable val with NULLs (simulates table UNION branch)
        let null_batch = RecordBatch::try_new(
            Arc::clone(&manager_schema),
            vec![
                Arc::new(Int64Array::from(vec![4, 5, 6])),
                Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])),
            ],
        )?;
        spill_file.append_batch(&null_batch)?;

        let finished = spill_file.finish()?.unwrap();
        let stream = manager.read_spill_as_stream(finished, None)?;

        // The stream must advertise the manager's nullable schema, not the
        // first batch's non-nullable one.
        assert_eq!(stream.schema(), manager_schema);

        let read_back = stream.try_collect::<Vec<_>>().await?;
        assert_eq!(read_back.len(), 2);

        // Both batches come back stamped with the SpillManager's schema.
        assert_eq!(
            read_back[0],
            strict_batch.with_schema(Arc::clone(&manager_schema))?
        );
        assert_eq!(read_back[1], null_batch);

        Ok(())
    }
}

0 commit comments

Comments
 (0)