@@ -113,6 +113,18 @@ macro_rules! make_function_inputs2 {
113113 } )
114114 . collect:: <$ARRAY_TYPE>( )
115115 } } ;
116+ ( $ARG1: expr, $ARG2: expr, $NAME1: expr, $NAME2: expr, $ARRAY_TYPE1: ident, $ARRAY_TYPE2: ident, $FUNC: block) => { {
117+ let arg1 = downcast_arg!( $ARG1, $NAME1, $ARRAY_TYPE1) ;
118+ let arg2 = downcast_arg!( $ARG2, $NAME2, $ARRAY_TYPE2) ;
119+
120+ arg1. iter( )
121+ . zip( arg2. iter( ) )
122+ . map( |( a1, a2) | match ( a1, a2) {
123+ ( Some ( a1) , Some ( a2) ) => Some ( $FUNC( a1, a2. try_into( ) . ok( ) ?) ) ,
124+ _ => None ,
125+ } )
126+ . collect:: <$ARRAY_TYPE1>( )
127+ } } ;
116128}
117129
118130math_unary_function ! ( "sqrt" , sqrt) ;
@@ -124,7 +136,6 @@ math_unary_function!("acos", acos);
124136math_unary_function ! ( "atan" , atan) ;
125137math_unary_function ! ( "floor" , floor) ;
126138math_unary_function ! ( "ceil" , ceil) ;
127- math_unary_function ! ( "round" , round) ;
128139math_unary_function ! ( "trunc" , trunc) ;
129140math_unary_function ! ( "abs" , abs) ;
130141math_unary_function ! ( "signum" , signum) ;
@@ -149,6 +160,59 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
149160 Ok ( ColumnarValue :: Array ( Arc :: new ( array) ) )
150161}
151162
163+ /// Round SQL function
164+ pub fn round ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
165+ if args. len ( ) != 1 && args. len ( ) != 2 {
166+ return Err ( DataFusionError :: Internal ( format ! (
167+ "round function requires one or two arguments, got {}" ,
168+ args. len( )
169+ ) ) ) ;
170+ }
171+
172+ let mut decimal_places =
173+ & ( Arc :: new ( Int64Array :: from_value ( 0 , args[ 0 ] . len ( ) ) ) as ArrayRef ) ;
174+
175+ if args. len ( ) == 2 {
176+ decimal_places = & args[ 1 ] ;
177+ }
178+
179+ match args[ 0 ] . data_type ( ) {
180+ DataType :: Float64 => Ok ( Arc :: new ( make_function_inputs2 ! (
181+ & args[ 0 ] ,
182+ decimal_places,
183+ "value" ,
184+ "decimal_places" ,
185+ Float64Array ,
186+ Int64Array ,
187+ {
188+ |value: f64 , decimal_places: i64 | {
189+ ( value * 10.0_f64 . powi( decimal_places. try_into( ) . unwrap( ) ) ) . round( )
190+ / 10.0_f64 . powi( decimal_places. try_into( ) . unwrap( ) )
191+ }
192+ }
193+ ) ) as ArrayRef ) ,
194+
195+ DataType :: Float32 => Ok ( Arc :: new ( make_function_inputs2 ! (
196+ & args[ 0 ] ,
197+ decimal_places,
198+ "value" ,
199+ "decimal_places" ,
200+ Float32Array ,
201+ Int64Array ,
202+ {
203+ |value: f32 , decimal_places: i64 | {
204+ ( value * 10.0_f32 . powi( decimal_places. try_into( ) . unwrap( ) ) ) . round( )
205+ / 10.0_f32 . powi( decimal_places. try_into( ) . unwrap( ) )
206+ }
207+ }
208+ ) ) as ArrayRef ) ,
209+
210+ other => Err ( DataFusionError :: Internal ( format ! (
211+ "Unsupported data type {other:?} for function round"
212+ ) ) ) ,
213+ }
214+ }
215+
152216/// Power SQL function
153217pub fn power ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
154218 match args[ 0 ] . data_type ( ) {
@@ -365,4 +429,40 @@ mod tests {
365429 assert_eq ! ( floats. value( 2 ) , 4.0 ) ;
366430 assert_eq ! ( floats. value( 3 ) , 4.0 ) ;
367431 }
432+
433+ #[ test]
434+ fn test_round_f32 ( ) {
435+ let args: Vec < ArrayRef > = vec ! [
436+ Arc :: new( Float32Array :: from( vec![ 125.2345 ; 10 ] ) ) , // input
437+ Arc :: new( Int64Array :: from( vec![ 0 , 1 , 2 , 3 , 4 , 5 , -1 , -2 , -3 , -4 ] ) ) , // decimal_places
438+ ] ;
439+
440+ let result = round ( & args) . expect ( "failed to initialize function round" ) ;
441+ let floats =
442+ as_float32_array ( & result) . expect ( "failed to initialize function round" ) ;
443+
444+ let expected = Float32Array :: from ( vec ! [
445+ 125.0 , 125.2 , 125.23 , 125.235 , 125.2345 , 125.2345 , 130.0 , 100.0 , 0.0 , 0.0 ,
446+ ] ) ;
447+
448+ assert_eq ! ( floats, & expected) ;
449+ }
450+
451+ #[ test]
452+ fn test_round_f64 ( ) {
453+ let args: Vec < ArrayRef > = vec ! [
454+ Arc :: new( Float64Array :: from( vec![ 125.2345 ; 10 ] ) ) , // input
455+ Arc :: new( Int64Array :: from( vec![ 0 , 1 , 2 , 3 , 4 , 5 , -1 , -2 , -3 , -4 ] ) ) , // decimal_places
456+ ] ;
457+
458+ let result = round ( & args) . expect ( "failed to initialize function round" ) ;
459+ let floats =
460+ as_float64_array ( & result) . expect ( "failed to initialize function round" ) ;
461+
462+ let expected = Float64Array :: from ( vec ! [
463+ 125.0 , 125.2 , 125.23 , 125.235 , 125.2345 , 125.2345 , 130.0 , 100.0 , 0.0 , 0.0 ,
464+ ] ) ;
465+
466+ assert_eq ! ( floats, & expected) ;
467+ }
368468}
0 commit comments