Skip to content

Commit 50b3509

Browse files
committed
Add test for runtime memory limiting
1 parent 2ff37a4 commit 50b3509

6 files changed

Lines changed: 152 additions & 13 deletions

File tree

datafusion/core/src/execution/memory_manager/proxy.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ impl MemoryConsumer for MemoryConsumerProxy {
8888
}
8989

9090
async fn spill(&self) -> Result<usize, DataFusionError> {
91-
Err(DataFusionError::ResourcesExhausted(
92-
"Cannot spill AggregationState".to_owned(),
93-
))
91+
Err(DataFusionError::ResourcesExhausted(format!(
92+
"Cannot spill {}",
93+
self.name
94+
)))
9495
}
9596

9697
fn mem_used(&self) -> usize {

datafusion/core/src/physical_plan/aggregates/hash.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ impl GroupedHashAggregateStream {
135135
aggregate_expressions,
136136
accumulators: Accumulators {
137137
memory_consumer: MemoryConsumerProxy::new(
138-
"Accumulators",
138+
"GroupBy Hash Accumulators",
139139
MemoryConsumerId::new(partition),
140140
Arc::clone(&context.runtime_env().memory_manager),
141141
),

datafusion/core/src/physical_plan/aggregates/row_hash.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ impl GroupedHashAggregateStreamV2 {
144144

145145
let aggr_state = AggregationState {
146146
memory_consumer: MemoryConsumerProxy::new(
147-
"AggregationState",
147+
"GroupBy Hash (Row) AggregationState",
148148
MemoryConsumerId::new(partition),
149149
Arc::clone(&context.runtime_env().memory_manager),
150150
),

datafusion/core/src/physical_plan/sorts/sort.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ impl ExternalSorter {
118118
) -> Result<()> {
119119
if input.num_rows() > 0 {
120120
let size = batch_byte_size(&input);
121+
debug!("Inserting {} rows of {} bytes", input.num_rows(), size);
121122
self.try_grow(size).await?;
122123
self.metrics.mem_used().add(size);
123124
let mut in_mem_batches = self.in_mem_batches.lock().await;
@@ -272,6 +273,13 @@ impl MemoryConsumer for ExternalSorter {
272273
}
273274

274275
async fn spill(&self) -> Result<usize> {
276+
let partition = self.partition_id();
277+
let mut in_mem_batches = self.in_mem_batches.lock().await;
278+
// we could always get a chance to free some memory as long as we are holding some
279+
if in_mem_batches.len() == 0 {
280+
return Ok(0);
281+
}
282+
275283
debug!(
276284
"{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)",
277285
self.name(),
@@ -280,13 +288,6 @@ impl MemoryConsumer for ExternalSorter {
280288
self.spill_count()
281289
);
282290

283-
let partition = self.partition_id();
284-
let mut in_mem_batches = self.in_mem_batches.lock().await;
285-
// we could always get a chance to free some memory as long as we are holding some
286-
if in_mem_batches.len() == 0 {
287-
return Ok(0);
288-
}
289-
290291
let tracking_metrics = self
291292
.metrics_set
292293
.new_intermediate_tracking(partition, self.runtime.clone());
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! This module contains tests for limiting memory at runtime in DataFusion
19+
20+
use std::sync::Arc;
21+
22+
use arrow::record_batch::RecordBatch;
23+
use datafusion::datasource::MemTable;
24+
use datafusion::execution::disk_manager::DiskManagerConfig;
25+
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
26+
use datafusion_common::assert_contains;
27+
28+
use datafusion::prelude::{SessionConfig, SessionContext};
29+
use test_utils::{stagger_batch, AccessLogGenerator};
30+
31+
#[cfg(test)]
32+
#[ctor::ctor]
33+
fn init() {
34+
let _ = env_logger::try_init();
35+
}
36+
37+
#[tokio::test]
38+
async fn oom_sort() {
39+
run_limit_test(
40+
"select * from t order by host DESC",
41+
"Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)",
42+
)
43+
.await
44+
}
45+
46+
#[tokio::test]
47+
async fn group_by_none() {
48+
run_limit_test(
49+
"select median(image) from t",
50+
"Resources exhausted: Cannot spill AggregationState",
51+
)
52+
.await
53+
}
54+
55+
#[tokio::test]
56+
async fn group_by_row_hash() {
57+
run_limit_test(
58+
"select count(*) from t GROUP BY response_bytes",
59+
"Resources exhausted: Cannot spill GroupBy Hash (Row) AggregationState",
60+
)
61+
.await
62+
}
63+
64+
#[tokio::test]
65+
async fn group_by_hash() {
66+
run_limit_test(
67+
// group by dict column
68+
"select count(*) from t GROUP BY service, host, pod, container",
69+
"Resources exhausted: Cannot spill GroupBy Hash Accumulators",
70+
)
71+
.await
72+
}
73+
74+
/// 100K memory limit
75+
const MEMORY_LIMIT_BYTES: usize = 50;
76+
const MEMORY_FRACTION: f64 = 0.95;
77+
78+
/// runs the specified query against 1000 rows with a 50
79+
/// byte memory limit and no disk manager enabled.
80+
async fn run_limit_test(query: &str, expected_error: &str) {
81+
let generator = AccessLogGenerator::new().with_row_limit(Some(1000));
82+
83+
let batches: Vec<RecordBatch> = generator
84+
// split up into more than one batch, as the size limit in sort is not enforced until the second batch
85+
.flat_map(stagger_batch)
86+
.collect();
87+
88+
let table = MemTable::try_new(batches[0].schema(), vec![batches]).unwrap();
89+
90+
let rt_config = RuntimeConfig::new()
91+
// do not allow spilling
92+
.with_disk_manager(DiskManagerConfig::Disabled)
93+
// Only allow 50 bytes
94+
.with_memory_limit(MEMORY_LIMIT_BYTES, MEMORY_FRACTION);
95+
96+
let runtime = RuntimeEnv::new(rt_config).unwrap();
97+
98+
let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
99+
ctx.register_table("t", Arc::new(table))
100+
.expect("registering table");
101+
102+
let df = ctx.sql(query).await.expect("Planning query");
103+
104+
match df.collect().await {
105+
Ok(_batches) => {
106+
panic!("Unexpected success when running, expected memory limit failure")
107+
}
108+
Err(e) => {
109+
assert_contains!(e.to_string(), expected_error);
110+
}
111+
}
112+
}

test-utils/src/lib.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
use arrow::record_batch::RecordBatch;
2020
use datafusion_common::cast::as_int32_array;
2121
use rand::prelude::StdRng;
22-
use rand::Rng;
22+
use rand::{Rng, SeedableRng};
2323

2424
mod data_gen;
2525

@@ -68,3 +68,28 @@ pub fn add_empty_batches(
6868
})
6969
.collect()
7070
}
71+
72+
/// "stagger" batches: split the batch into randomly sized batches
73+
pub fn stagger_batch(batch: RecordBatch) -> Vec<RecordBatch> {
74+
let seed = 42;
75+
stagger_batch_with_seed(batch, seed)
76+
}
77+
78+
/// "stagger" batches: split the batch into randomly sized (possibly empty) batches
79+
/// using the specified value for a rng seed
80+
pub fn stagger_batch_with_seed(batch: RecordBatch, seed: u64) -> Vec<RecordBatch> {
81+
let mut batches = vec![];
82+
83+
// use a random number generator to pick a random sized output
84+
let mut rng = StdRng::seed_from_u64(seed);
85+
86+
let mut remainder = batch;
87+
while remainder.num_rows() > 0 {
88+
let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
89+
90+
batches.push(remainder.slice(0, batch_size));
91+
remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size);
92+
}
93+
94+
batches
95+
}

0 commit comments

Comments
 (0)