Skip to content

Commit b47ab7c

Browse files
authored
Improve ergonomics of physical expr lit (#2828)
* Improve ergonomics of physical expr lit * Update usages of `lit` * Clippy
1 parent 88b88d4 commit b47ab7c

8 files changed

Lines changed: 252 additions & 651 deletions

File tree

datafusion/core/src/physical_optimizer/aggregate_statistics.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ mod tests {
525525
expressions::binary(
526526
expressions::col("a", &schema)?,
527527
Operator::Gt,
528-
expressions::lit(ScalarValue::from(1u32)),
528+
expressions::lit(1u32),
529529
&schema,
530530
)?,
531531
source,
@@ -568,7 +568,7 @@ mod tests {
568568
expressions::binary(
569569
expressions::col("a", &schema)?,
570570
Operator::Gt,
571-
expressions::lit(ScalarValue::from(1u32)),
571+
expressions::lit(1u32),
572572
&schema,
573573
)?,
574574
source,

datafusion/core/src/physical_plan/aggregates/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ mod tests {
718718
};
719719

720720
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Count::new(
721-
lit(ScalarValue::Int8(Some(1))),
721+
lit(1i8),
722722
"COUNT(1)".to_string(),
723723
DataType::Int64,
724724
))];

datafusion/core/src/physical_plan/filter.rs

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ mod tests {
239239
use crate::physical_plan::ExecutionPlan;
240240
use crate::physical_plan::{collect, with_new_children_if_necessary};
241241
use crate::prelude::SessionContext;
242-
use crate::scalar::ScalarValue;
243242
use crate::test;
244243
use crate::test_util;
245244
use datafusion_expr::Operator;
@@ -255,19 +254,9 @@ mod tests {
255254
let csv = test::scan_partitioned_csv(partitions)?;
256255

257256
let predicate: Arc<dyn PhysicalExpr> = binary(
258-
binary(
259-
col("c2", &schema)?,
260-
Operator::Gt,
261-
lit(ScalarValue::from(1u32)),
262-
&schema,
263-
)?,
257+
binary(col("c2", &schema)?, Operator::Gt, lit(1u32), &schema)?,
264258
Operator::And,
265-
binary(
266-
col("c2", &schema)?,
267-
Operator::Lt,
268-
lit(ScalarValue::from(4u32)),
269-
&schema,
270-
)?,
259+
binary(col("c2", &schema)?, Operator::Lt, lit(4u32), &schema)?,
271260
&schema,
272261
)?;
273262

@@ -292,12 +281,8 @@ mod tests {
292281
let partitions = 4;
293282
let input = test::scan_partitioned_csv(partitions)?;
294283

295-
let predicate: Arc<dyn PhysicalExpr> = binary(
296-
col("c2", &schema)?,
297-
Operator::Gt,
298-
lit(ScalarValue::from(1u32)),
299-
&schema,
300-
)?;
284+
let predicate: Arc<dyn PhysicalExpr> =
285+
binary(col("c2", &schema)?, Operator::Gt, lit(1u32), &schema)?;
301286

302287
let filter: Arc<dyn ExecutionPlan> =
303288
Arc::new(FilterExec::try_new(predicate, input.clone())?);

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,10 @@ mod tests {
316316
let schema = batch.schema();
317317

318318
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
319-
let when1 = lit(ScalarValue::Utf8(Some("foo".to_string())));
320-
let then1 = lit(ScalarValue::Int32(Some(123)));
321-
let when2 = lit(ScalarValue::Utf8(Some("bar".to_string())));
322-
let then2 = lit(ScalarValue::Int32(Some(456)));
319+
let when1 = lit("foo");
320+
let then1 = lit(123i32);
321+
let when2 = lit("bar");
322+
let then2 = lit(456i32);
323323

324324
let expr = case(
325325
Some(col("a", &schema)?),
@@ -345,11 +345,11 @@ mod tests {
345345
let schema = batch.schema();
346346

347347
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
348-
let when1 = lit(ScalarValue::Utf8(Some("foo".to_string())));
349-
let then1 = lit(ScalarValue::Int32(Some(123)));
350-
let when2 = lit(ScalarValue::Utf8(Some("bar".to_string())));
351-
let then2 = lit(ScalarValue::Int32(Some(456)));
352-
let else_value = lit(ScalarValue::Int32(Some(999)));
348+
let when1 = lit("foo");
349+
let then1 = lit(123i32);
350+
let when2 = lit("bar");
351+
let then2 = lit(456i32);
352+
let else_value = lit(999i32);
353353

354354
let expr = case(
355355
Some(col("a", &schema)?),
@@ -376,10 +376,10 @@ mod tests {
376376
let schema = batch.schema();
377377

378378
// CASE a when 0 THEN float64(null) ELSE 25.0 / cast(a, float64) END
379-
let when1 = lit(ScalarValue::Int32(Some(0)));
379+
let when1 = lit(0i32);
380380
let then1 = lit(ScalarValue::Float64(None));
381381
let else_value = binary(
382-
lit(ScalarValue::Float64(Some(25.0))),
382+
lit(25.0f64),
383383
Operator::Divide,
384384
cast(col("a", &schema)?, &batch.schema(), Float64)?,
385385
&batch.schema(),
@@ -412,17 +412,17 @@ mod tests {
412412
let when1 = binary(
413413
col("a", &schema)?,
414414
Operator::Eq,
415-
lit(ScalarValue::Utf8(Some("foo".to_string()))),
415+
lit("foo"),
416416
&batch.schema(),
417417
)?;
418-
let then1 = lit(ScalarValue::Int32(Some(123)));
418+
let then1 = lit(123i32);
419419
let when2 = binary(
420420
col("a", &schema)?,
421421
Operator::Eq,
422-
lit(ScalarValue::Utf8(Some("bar".to_string()))),
422+
lit("bar"),
423423
&batch.schema(),
424424
)?;
425-
let then2 = lit(ScalarValue::Int32(Some(456)));
425+
let then2 = lit(456i32);
426426

427427
let expr = case(None, &[(when1, then1), (when2, then2)], None)?;
428428
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
@@ -444,14 +444,9 @@ mod tests {
444444
let schema = batch.schema();
445445

446446
// CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE float64(null) END
447-
let when1 = binary(
448-
col("a", &schema)?,
449-
Operator::Gt,
450-
lit(ScalarValue::Int32(Some(0))),
451-
&batch.schema(),
452-
)?;
447+
let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
453448
let then1 = binary(
454-
lit(ScalarValue::Float64(Some(25.0))),
449+
lit(25.0f64),
455450
Operator::Divide,
456451
cast(col("a", &schema)?, &batch.schema(), Float64)?,
457452
&batch.schema(),
@@ -488,18 +483,18 @@ mod tests {
488483
let when1 = binary(
489484
col("a", &schema)?,
490485
Operator::Eq,
491-
lit(ScalarValue::Utf8(Some("foo".to_string()))),
486+
lit("foo"),
492487
&batch.schema(),
493488
)?;
494-
let then1 = lit(ScalarValue::Int32(Some(123)));
489+
let then1 = lit(123i32);
495490
let when2 = binary(
496491
col("a", &schema)?,
497492
Operator::Eq,
498-
lit(ScalarValue::Utf8(Some("bar".to_string()))),
493+
lit("bar"),
499494
&batch.schema(),
500495
)?;
501-
let then2 = lit(ScalarValue::Int32(Some(456)));
502-
let else_value = lit(ScalarValue::Int32(Some(999)));
496+
let then2 = lit(456i32);
497+
let else_value = lit(999i32);
503498

504499
let expr = case(None, &[(when1, then1), (when2, then2)], Some(else_value))?;
505500
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
@@ -525,11 +520,11 @@ mod tests {
525520
let when = binary(
526521
col("a", &schema)?,
527522
Operator::Eq,
528-
lit(ScalarValue::Utf8(Some("foo".to_string()))),
523+
lit("foo"),
529524
&batch.schema(),
530525
)?;
531-
let then = lit(ScalarValue::Float64(Some(123.3)));
532-
let else_value = lit(ScalarValue::Int32(Some(999)));
526+
let then = lit(123.3f64);
527+
let else_value = lit(999i32);
533528

534529
let expr = case(None, &[(when, then)], Some(else_value))?;
535530
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
@@ -555,7 +550,7 @@ mod tests {
555550
let when = binary(
556551
col("load4", &schema)?,
557552
Operator::Eq,
558-
lit(ScalarValue::Float64(Some(1.77))),
553+
lit(1.77f64),
559554
&batch.schema(),
560555
)?;
561556
let then = col("load4", &schema)?;
@@ -582,7 +577,7 @@ mod tests {
582577

583578
// SELECT CASE load4 WHEN 1.77 THEN load4 END
584579
let expr = col("load4", &schema)?;
585-
let when = lit(ScalarValue::Float64(Some(1.77)));
580+
let when = lit(1.77f64);
586581
let then = col("load4", &schema)?;
587582

588583
let expr = case(Some(expr), &[(when, then)], None)?;

datafusion/physical-expr/src/expressions/get_indexed_field.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ mod tests {
244244
#[test]
245245
fn get_indexed_field_invalid_scalar() -> Result<()> {
246246
let schema = list_schema("l");
247-
let expr = lit(ScalarValue::Utf8(Some("a".to_string())));
247+
let expr = lit("a");
248248
get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "Execution error: get indexed field is only possible on lists with int64 indexes or struct with utf8 indexes. Tried Utf8 with Int64(0) index")
249249
}
250250

0 commit comments

Comments
 (0)