@@ -39,7 +39,7 @@ use datafusion_common::{
3939} ;
4040use datafusion_expr:: {
4141 BinaryExpr , Case , ColumnarValue , Expr , Like , Operator , Volatility , and,
42- binary:: BinaryTypeCoercer , lit, or,
42+ binary:: BinaryTypeCoercer , lit, or, preimage :: PreimageResult ,
4343} ;
4444use datafusion_expr:: { Cast , TryCast , simplify:: ExprSimplifyResult } ;
4545use datafusion_expr:: { expr:: ScalarFunction , interval_arithmetic:: NullableInterval } ;
@@ -51,14 +51,17 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP
5151
5252use super :: inlist_simplifier:: ShortenInListSimplifier ;
5353use super :: utils:: * ;
54- use crate :: analyzer:: type_coercion:: TypeCoercionRewriter ;
5554use crate :: simplify_expressions:: SimplifyContext ;
5655use crate :: simplify_expressions:: regex:: simplify_regex_expr;
5756use crate :: simplify_expressions:: unwrap_cast:: {
5857 is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
5958 is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
6059 unwrap_cast_in_comparison_for_binary,
6160} ;
61+ use crate :: {
62+ analyzer:: type_coercion:: TypeCoercionRewriter ,
63+ simplify_expressions:: udf_preimage:: rewrite_with_preimage,
64+ } ;
6265use datafusion_expr:: expr_rewriter:: rewrite_with_guarantees_map;
6366use datafusion_expr_common:: casts:: try_cast_literal_to_type;
6467use indexmap:: IndexSet ;
@@ -1969,12 +1972,85 @@ impl TreeNodeRewriter for Simplifier<'_> {
19691972 } ) )
19701973 }
19711974
1975+ // =======================================
1976+ // preimage_in_comparison
1977+ // =======================================
1978+ //
1979+ // For case:
1980+ // date_part('YEAR', expr) op literal
1981+ //
1982+ // For details see datafusion_expr::ScalarUDFImpl::preimage
1983+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
1984+ use datafusion_expr:: Operator :: * ;
1985+ let is_preimage_op = matches ! (
1986+ op,
1987+ Eq | NotEq
1988+ | Lt
1989+ | LtEq
1990+ | Gt
1991+ | GtEq
1992+ | IsDistinctFrom
1993+ | IsNotDistinctFrom
1994+ ) ;
1995+ if !is_preimage_op || is_null ( & right) {
1996+ return Ok ( Transformed :: no ( Expr :: BinaryExpr ( BinaryExpr {
1997+ left,
1998+ op,
1999+ right,
2000+ } ) ) ) ;
2001+ }
2002+
2003+ if let PreimageResult :: Range { interval, expr } =
2004+ get_preimage ( left. as_ref ( ) , right. as_ref ( ) , info) ?
2005+ {
2006+ rewrite_with_preimage ( * interval, op, expr) ?
2007+ } else if let Some ( swapped) = op. swap ( ) {
2008+ if let PreimageResult :: Range { interval, expr } =
2009+ get_preimage ( right. as_ref ( ) , left. as_ref ( ) , info) ?
2010+ {
2011+ rewrite_with_preimage ( * interval, swapped, expr) ?
2012+ } else {
2013+ Transformed :: no ( Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) )
2014+ }
2015+ } else {
2016+ Transformed :: no ( Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) )
2017+ }
2018+ }
2019+
19722020 // no additional rewrites possible
19732021 expr => Transformed :: no ( expr) ,
19742022 } )
19752023 }
19762024}
19772025
2026+ fn get_preimage (
2027+ left_expr : & Expr ,
2028+ right_expr : & Expr ,
2029+ info : & SimplifyContext ,
2030+ ) -> Result < PreimageResult > {
2031+ let Expr :: ScalarFunction ( ScalarFunction { func, args } ) = left_expr else {
2032+ return Ok ( PreimageResult :: None ) ;
2033+ } ;
2034+ if !is_literal_or_literal_cast ( right_expr) {
2035+ return Ok ( PreimageResult :: None ) ;
2036+ }
2037+ if func. signature ( ) . volatility != Volatility :: Immutable {
2038+ return Ok ( PreimageResult :: None ) ;
2039+ }
2040+ func. preimage ( args, right_expr, info)
2041+ }
2042+
2043+ fn is_literal_or_literal_cast ( expr : & Expr ) -> bool {
2044+ match expr {
2045+ Expr :: Literal ( _, _) => true ,
2046+ Expr :: Cast ( Cast { expr, .. } ) => matches ! ( expr. as_ref( ) , Expr :: Literal ( _, _) ) ,
2047+ Expr :: TryCast ( TryCast { expr, .. } ) => {
2048+ matches ! ( expr. as_ref( ) , Expr :: Literal ( _, _) )
2049+ }
2050+ _ => false ,
2051+ }
2052+ }
2053+
19782054fn as_string_scalar ( expr : & Expr ) -> Option < ( DataType , & Option < String > ) > {
19792055 match expr {
19802056 Expr :: Literal ( ScalarValue :: Utf8 ( s) , _) => Some ( ( DataType :: Utf8 , s) ) ,
0 commit comments