Skip to content

Commit dbf7781

Browse files
xudong9632010YOUY01martin-g
authored
Cherry pick limit pruning from upstream (#29)
Co-authored-by: Yongting You <2010youy01@gmail.com> Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
1 parent 5f37deb commit dbf7781

23 files changed

Lines changed: 749 additions & 137 deletions

File tree

datafusion/core/tests/parquet/mod.rs

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use arrow::{
3030
record_batch::RecordBatch,
3131
util::pretty::pretty_format_batches,
3232
};
33+
use arrow_schema::SchemaRef;
3334
use chrono::{Datelike, Duration, TimeDelta};
3435
use datafusion::{
3536
datasource::{provider_as_source, TableProvider},
@@ -109,6 +110,26 @@ struct ContextWithParquet {
109110
ctx: SessionContext,
110111
}
111112

113+
/// Totals collected for one named pruning metric.
///
/// `TestOutput::pruning_metric` sums the `pruned`, `matched` and
/// `fully_matched` counts of every parquet metric with the requested name and
/// returns them in this struct, replacing the former `(pruned, matched)` tuple.
///
/// The derives let a value be printed, copied and compared as a whole, so it
/// can be used directly in `assert_eq!` and reused after an accessor call.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct PruningMetric {
    /// Sum of the `pruned` counts
    total_pruned: usize,
    /// Sum of the `matched` counts
    total_matched: usize,
    /// Sum of the `fully_matched` counts
    total_fully_matched: usize,
}

impl PruningMetric {
    /// Total number of pruned containers
    pub fn total_pruned(&self) -> usize {
        self.total_pruned
    }

    /// Total number of matched containers
    pub fn total_matched(&self) -> usize {
        self.total_matched
    }

    /// Total number of fully matched containers
    pub fn total_fully_matched(&self) -> usize {
        self.total_fully_matched
    }
}
132+
112133
/// The output of running one of the test cases
113134
struct TestOutput {
114135
/// The input query SQL
@@ -126,8 +147,8 @@ struct TestOutput {
126147
impl TestOutput {
127148
/// retrieve the value of the named metric, if any
128149
fn metric_value(&self, metric_name: &str) -> Option<usize> {
129-
if let Some((pruned, _matched)) = self.pruning_metric(metric_name) {
130-
return Some(pruned);
150+
if let Some(pm) = self.pruning_metric(metric_name) {
151+
return Some(pm.total_pruned());
131152
}
132153

133154
self.parquet_metrics
@@ -140,9 +161,10 @@ impl TestOutput {
140161
})
141162
}
142163

143-
fn pruning_metric(&self, metric_name: &str) -> Option<(usize, usize)> {
164+
fn pruning_metric(&self, metric_name: &str) -> Option<PruningMetric> {
144165
let mut total_pruned = 0;
145166
let mut total_matched = 0;
167+
let mut total_fully_matched = 0;
146168
let mut found = false;
147169

148170
for metric in self.parquet_metrics.iter() {
@@ -154,13 +176,19 @@ impl TestOutput {
154176
{
155177
total_pruned += pruning_metrics.pruned();
156178
total_matched += pruning_metrics.matched();
179+
total_fully_matched += pruning_metrics.fully_matched();
180+
157181
found = true;
158182
}
159183
}
160184
}
161185

162186
if found {
163-
Some((total_pruned, total_matched))
187+
Some(PruningMetric {
188+
total_pruned,
189+
total_matched,
190+
total_fully_matched,
191+
})
164192
} else {
165193
None
166194
}
@@ -172,39 +200,33 @@ impl TestOutput {
172200
}
173201

174202
/// The number of row_groups pruned / matched by bloom filter
175-
fn row_groups_bloom_filter(&self) -> Option<(usize, usize)> {
203+
fn row_groups_bloom_filter(&self) -> Option<PruningMetric> {
176204
self.pruning_metric("row_groups_pruned_bloom_filter")
177205
}
178206

179207
/// The number of row_groups matched by statistics
180208
fn row_groups_matched_statistics(&self) -> Option<usize> {
181209
self.pruning_metric("row_groups_pruned_statistics")
182-
.map(|(_pruned, matched)| matched)
210+
.map(|pm| pm.total_matched())
183211
}
184212

185-
/*
186213
/// The number of row_groups fully matched by statistics
187214
fn row_groups_fully_matched_statistics(&self) -> Option<usize> {
188-
self.metric_value("row_groups_fully_matched_statistics")
189-
}
190-
191-
/// The number of row groups pruned by limit pruning
192-
fn limit_pruned_row_groups(&self) -> Option<usize> {
193-
self.metric_value("limit_pruned_row_groups")
215+
self.pruning_metric("row_groups_pruned_statistics")
216+
.map(|pm| pm.total_fully_matched())
194217
}
195-
*/
196218

197219
/// The number of row_groups pruned by statistics
198220
fn row_groups_pruned_statistics(&self) -> Option<usize> {
199221
self.pruning_metric("row_groups_pruned_statistics")
200-
.map(|(pruned, _matched)| pruned)
222+
.map(|pm| pm.total_pruned())
201223
}
202224

203225
/// Metric `files_ranges_pruned_statistics` tracks both pruned and matched count,
204226
/// for testing purposes, here it only aggregates the `pruned` count.
205227
fn files_ranges_pruned_statistics(&self) -> Option<usize> {
206228
self.pruning_metric("files_ranges_pruned_statistics")
207-
.map(|(pruned, _matched)| pruned)
229+
.map(|pm| pm.total_pruned())
208230
}
209231

210232
/// The number of row_groups matched by bloom filter or statistics
@@ -213,22 +235,27 @@ impl TestOutput {
213235
/// filter: 7 total -> 3 matched, this function returns 3 for the final matched
214236
/// count.
215237
fn row_groups_matched(&self) -> Option<usize> {
216-
self.row_groups_bloom_filter()
217-
.map(|(_pruned, matched)| matched)
238+
self.row_groups_bloom_filter().map(|pm| pm.total_matched())
218239
}
219240

220241
/// The number of row_groups pruned
221242
fn row_groups_pruned(&self) -> Option<usize> {
222243
self.row_groups_bloom_filter()
223-
.map(|(pruned, _matched)| pruned)
244+
.map(|pm| pm.total_pruned())
224245
.zip(self.row_groups_pruned_statistics())
225246
.map(|(a, b)| a + b)
226247
}
227248

228249
/// The number of row pages pruned
229250
fn row_pages_pruned(&self) -> Option<usize> {
230251
self.pruning_metric("page_index_rows_pruned")
231-
.map(|(pruned, _matched)| pruned)
252+
.map(|pm| pm.total_pruned())
253+
}
254+
255+
/// The number of row groups pruned by limit pruning
256+
fn limit_pruned_row_groups(&self) -> Option<usize> {
257+
self.pruning_metric("limit_pruned_row_groups")
258+
.map(|pm| pm.total_pruned())
232259
}
233260

234261
fn description(&self) -> String {
@@ -247,6 +274,23 @@ impl ContextWithParquet {
247274
Self::with_config(scenario, unit, SessionConfig::new(), None, None).await
248275
}
249276

277+
/// Set custom schema and batches for the test
278+
pub async fn with_custom_data(
279+
scenario: Scenario,
280+
unit: Unit,
281+
schema: Arc<Schema>,
282+
batches: Vec<RecordBatch>,
283+
) -> Self {
284+
Self::with_config(
285+
scenario,
286+
unit,
287+
SessionConfig::new(),
288+
Some(schema),
289+
Some(batches),
290+
)
291+
.await
292+
}
293+
250294
// Set custom schema and batches for the test
251295
/*
252296
pub async fn with_custom_data(
@@ -270,7 +314,7 @@ impl ContextWithParquet {
270314
scenario: Scenario,
271315
unit: Unit,
272316
mut config: SessionConfig,
273-
custom_schema: Option<Arc<Schema>>,
317+
custom_schema: Option<SchemaRef>,
274318
custom_batches: Option<Vec<RecordBatch>>,
275319
) -> Self {
276320
// Use a single partition for deterministic results no matter how many CPUs the host has
@@ -1109,7 +1153,7 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
11091153
async fn make_test_file_rg(
11101154
scenario: Scenario,
11111155
row_per_group: usize,
1112-
custom_schema: Option<Arc<Schema>>,
1156+
custom_schema: Option<SchemaRef>,
11131157
custom_batches: Option<Vec<RecordBatch>>,
11141158
) -> NamedTempFile {
11151159
let mut output_file = tempfile::Builder::new()

datafusion/core/tests/parquet/row_group_pruning.rs

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818
//! This file contains an end to end test of parquet pruning. It writes
1919
//! data into a parquet file and then verifies row groups are pruned as
2020
//! expected.
21+
use std::sync::Arc;
22+
23+
use arrow::array::{ArrayRef, Int32Array, RecordBatch};
24+
use arrow_schema::{DataType, Field, Schema};
2125
use datafusion::prelude::SessionConfig;
22-
use datafusion_common::ScalarValue;
26+
use datafusion_common::{DataFusionError, ScalarValue};
2327
use itertools::Itertools;
2428

2529
use crate::parquet::Unit::RowGroup;
@@ -30,12 +34,12 @@ struct RowGroupPruningTest {
3034
query: String,
3135
expected_errors: Option<usize>,
3236
expected_row_group_matched_by_statistics: Option<usize>,
33-
// expected_row_group_fully_matched_by_statistics: Option<usize>,
37+
expected_row_group_fully_matched_by_statistics: Option<usize>,
3438
expected_row_group_pruned_by_statistics: Option<usize>,
3539
expected_files_pruned_by_statistics: Option<usize>,
3640
expected_row_group_matched_by_bloom_filter: Option<usize>,
3741
expected_row_group_pruned_by_bloom_filter: Option<usize>,
38-
// expected_limit_pruned_row_groups: Option<usize>,
42+
expected_limit_pruned_row_groups: Option<usize>,
3943
expected_rows: usize,
4044
}
4145
impl RowGroupPruningTest {
@@ -47,11 +51,11 @@ impl RowGroupPruningTest {
4751
expected_errors: None,
4852
expected_row_group_matched_by_statistics: None,
4953
expected_row_group_pruned_by_statistics: None,
50-
// expected_row_group_fully_matched_by_statistics: None,
54+
expected_row_group_fully_matched_by_statistics: None,
5155
expected_files_pruned_by_statistics: None,
5256
expected_row_group_matched_by_bloom_filter: None,
5357
expected_row_group_pruned_by_bloom_filter: None,
54-
// expected_limit_pruned_row_groups: None,
58+
expected_limit_pruned_row_groups: None,
5559
expected_rows: 0,
5660
}
5761
}
@@ -81,7 +85,6 @@ impl RowGroupPruningTest {
8185
}
8286

8387
// Set the expected fully matched row groups by statistics
84-
/*
8588
fn with_fully_matched_by_stats(
8689
mut self,
8790
fully_matched_by_stats: Option<usize>,
@@ -90,12 +93,6 @@ impl RowGroupPruningTest {
9093
self
9194
}
9295

93-
fn with_limit_pruned_row_groups(mut self, pruned_by_limit: Option<usize>) -> Self {
94-
self.expected_limit_pruned_row_groups = pruned_by_limit;
95-
self
96-
}
97-
*/
98-
9996
// Set the expected pruned row groups by statistics
10097
fn with_pruned_by_stats(mut self, pruned_by_stats: Option<usize>) -> Self {
10198
self.expected_row_group_pruned_by_statistics = pruned_by_stats;
@@ -119,6 +116,11 @@ impl RowGroupPruningTest {
119116
self
120117
}
121118

119+
fn with_limit_pruned_row_groups(mut self, pruned_by_limit: Option<usize>) -> Self {
120+
self.expected_limit_pruned_row_groups = pruned_by_limit;
121+
self
122+
}
123+
122124
/// Set the number of expected rows from the output of this test
123125
fn with_expected_rows(mut self, rows: usize) -> Self {
124126
self.expected_rows = rows;
@@ -155,12 +157,12 @@ impl RowGroupPruningTest {
155157
);
156158
let bloom_filter_metrics = output.row_groups_bloom_filter();
157159
assert_eq!(
158-
bloom_filter_metrics.map(|(_pruned, matched)| matched),
160+
bloom_filter_metrics.as_ref().map(|pm| pm.total_matched()),
159161
self.expected_row_group_matched_by_bloom_filter,
160162
"mismatched row_groups_matched_bloom_filter",
161163
);
162164
assert_eq!(
163-
bloom_filter_metrics.map(|(pruned, _matched)| pruned),
165+
bloom_filter_metrics.map(|pm| pm.total_pruned()),
164166
self.expected_row_group_pruned_by_bloom_filter,
165167
"mismatched row_groups_pruned_bloom_filter",
166168
);
@@ -175,6 +177,64 @@ impl RowGroupPruningTest {
175177
);
176178
}
177179

180+
// Execute the test with the current configuration
181+
async fn test_row_group_prune_with_custom_data(
182+
self,
183+
schema: Arc<Schema>,
184+
batches: Vec<RecordBatch>,
185+
max_row_per_group: usize,
186+
) {
187+
let output = ContextWithParquet::with_custom_data(
188+
self.scenario,
189+
RowGroup(max_row_per_group),
190+
schema,
191+
batches,
192+
)
193+
.await
194+
.query(&self.query)
195+
.await;
196+
197+
println!("{}", output.description());
198+
assert_eq!(
199+
output.predicate_evaluation_errors(),
200+
self.expected_errors,
201+
"mismatched predicate_evaluation error"
202+
);
203+
assert_eq!(
204+
output.row_groups_matched_statistics(),
205+
self.expected_row_group_matched_by_statistics,
206+
"mismatched row_groups_matched_statistics",
207+
);
208+
assert_eq!(
209+
output.row_groups_fully_matched_statistics(),
210+
self.expected_row_group_fully_matched_by_statistics,
211+
"mismatched row_groups_fully_matched_statistics",
212+
);
213+
assert_eq!(
214+
output.row_groups_pruned_statistics(),
215+
self.expected_row_group_pruned_by_statistics,
216+
"mismatched row_groups_pruned_statistics",
217+
);
218+
assert_eq!(
219+
output.files_ranges_pruned_statistics(),
220+
self.expected_files_pruned_by_statistics,
221+
"mismatched files_ranges_pruned_statistics",
222+
);
223+
assert_eq!(
224+
output.limit_pruned_row_groups(),
225+
self.expected_limit_pruned_row_groups,
226+
"mismatched limit_pruned_row_groups",
227+
);
228+
assert_eq!(
229+
output.result_rows,
230+
self.expected_rows,
231+
"Expected {} rows, got {}: {}",
232+
output.result_rows,
233+
self.expected_rows,
234+
output.description(),
235+
);
236+
}
237+
178238
// Execute the test with the current configuration
179239
/*
180240
async fn test_row_group_prune_with_custom_data(
@@ -1723,7 +1783,6 @@ async fn test_bloom_filter_decimal_dict() {
17231783
.await;
17241784
}
17251785

1726-
/*
17271786
// Helper function to create a batch with a single Int32 column.
17281787
fn make_i32_batch(
17291788
name: &str,
@@ -1950,15 +2009,13 @@ async fn test_limit_pruning_exceeds_fully_matched() -> datafusion_common::error:
19502009
.with_scenario(Scenario::Int)
19512010
.with_query(query)
19522011
.with_expected_errors(Some(0))
1953-
.with_expected_rows(10) // Total: 1 + 3 + 4 + 1 = 9 (less than limit)
2012+
.with_expected_rows(10) // Total: 1 + 4 + 4 + 1 = 10
19542013
.with_pruned_files(Some(0))
19552014
.with_matched_by_stats(Some(4)) // RG0,1,2,3 matched
19562015
.with_fully_matched_by_stats(Some(2))
19572016
.with_pruned_by_stats(Some(1)) // RG4 pruned
19582017
.with_limit_pruned_row_groups(Some(0)) // No limit pruning since we need all RGs
19592018
.test_row_group_prune_with_custom_data(schema, batches, 4)
19602019
.await;
1961-
19622020
Ok(())
19632021
}
1964-
*/

0 commit comments

Comments
 (0)