@@ -19,18 +19,16 @@ use std::any::Any;
1919use std:: sync:: Arc ;
2020
2121use crate :: strings:: make_and_append_view;
22- use crate :: utils:: make_scalar_function;
22+ use crate :: utils:: { make_scalar_function, utf8_to_str_type } ;
2323use arrow:: array:: {
24- Array , ArrayIter , ArrayRef , AsArray , Int64Array , NullBufferBuilder , StringArrayType ,
25- StringViewArray , StringViewBuilder ,
24+ Array , ArrayIter , ArrayRef , AsArray , GenericStringBuilder , Int64Array ,
25+ OffsetSizeTrait , StringArrayType , StringViewArray ,
2626} ;
27- use arrow:: buffer:: ScalarBuffer ;
2827use arrow:: datatypes:: DataType ;
28+ use arrow_buffer:: { NullBufferBuilder , ScalarBuffer } ;
2929use datafusion_common:: cast:: as_int64_array;
3030use datafusion_common:: { exec_err, plan_err, Result } ;
31- use datafusion_expr:: {
32- ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility ,
33- } ;
31+ use datafusion_expr:: { ColumnarValue , Documentation , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility } ;
3432use datafusion_macros:: user_doc;
3533
3634#[ user_doc(
@@ -90,15 +88,15 @@ impl ScalarUDFImpl for SubstrFunc {
9088 & self . signature
9189 }
9290
93- // `SubstrFunc` always generates `Utf8View` output for its efficiency.
94- fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
95- Ok ( DataType :: Utf8View )
91+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
92+ if arg_types[ 0 ] == DataType :: Utf8View {
93+ Ok ( DataType :: Utf8View )
94+ } else {
95+ utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
96+ }
9697 }
9798
98- fn invoke_with_args (
99- & self ,
100- args : datafusion_expr:: ScalarFunctionArgs ,
101- ) -> Result < ColumnarValue > {
99+ fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
102100 make_scalar_function ( substr, vec ! [ ] ) ( & args. args )
103101 }
104102
@@ -185,11 +183,11 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
185183 match args[ 0 ] . data_type ( ) {
186184 DataType :: Utf8 => {
187185 let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
188- string_substr :: < _ > ( string_array, & args[ 1 ..] )
186+ string_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
189187 }
190188 DataType :: LargeUtf8 => {
191189 let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
192- string_substr :: < _ > ( string_array, & args[ 1 ..] )
190+ string_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
193191 }
194192 DataType :: Utf8View => {
195193 let string_array = args[ 0 ] . as_string_view ( ) ;
@@ -425,9 +423,10 @@ fn string_view_substr(
425423 }
426424}
427425
428- fn string_substr < ' a , V > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
426+ fn string_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
429427where
430428 V : StringArrayType < ' a > ,
429+ T : OffsetSizeTrait ,
431430{
432431 let start_array = as_int64_array ( & args[ 0 ] ) ?;
433432 let count_array_opt = if args. len ( ) == 2 {
@@ -442,7 +441,7 @@ where
442441 match args. len ( ) {
443442 1 => {
444443 let iter = ArrayIter :: new ( string_array) ;
445- let mut result_builder = StringViewBuilder :: new ( ) ;
444+ let mut result_builder = GenericStringBuilder :: < T > :: new ( ) ;
446445 for ( string, start) in iter. zip ( start_array. iter ( ) ) {
447446 match ( string, start) {
448447 ( Some ( string) , Some ( start) ) => {
@@ -465,7 +464,7 @@ where
465464 2 => {
466465 let iter = ArrayIter :: new ( string_array) ;
467466 let count_array = count_array_opt. unwrap ( ) ;
468- let mut result_builder = StringViewBuilder :: new ( ) ;
467+ let mut result_builder = GenericStringBuilder :: < T > :: new ( ) ;
469468
470469 for ( ( string, start) , count) in
471470 iter. zip ( start_array. iter ( ) ) . zip ( count_array. iter ( ) )
@@ -507,8 +506,8 @@ where
507506
508507#[ cfg( test) ]
509508mod tests {
510- use arrow:: array:: { Array , StringViewArray } ;
511- use arrow:: datatypes:: DataType :: Utf8View ;
509+ use arrow:: array:: { Array , StringArray , StringViewArray } ;
510+ use arrow:: datatypes:: DataType :: { Utf8 , Utf8View } ;
512511
513512 use datafusion_common:: { exec_err, Result , ScalarValue } ;
514513 use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
@@ -618,8 +617,8 @@ mod tests {
618617 ] ,
619618 Ok ( Some ( "alphabet" ) ) ,
620619 & str ,
621- Utf8View ,
622- StringViewArray
620+ Utf8 ,
621+ StringArray
623622 ) ;
624623 test_function ! (
625624 SubstrFunc :: new( ) ,
@@ -629,8 +628,8 @@ mod tests {
629628 ] ,
630629 Ok ( Some ( "ésoj" ) ) ,
631630 & str ,
632- Utf8View ,
633- StringViewArray
631+ Utf8 ,
632+ StringArray
634633 ) ;
635634 test_function ! (
636635 SubstrFunc :: new( ) ,
@@ -640,8 +639,8 @@ mod tests {
640639 ] ,
641640 Ok ( Some ( "joséésoj" ) ) ,
642641 & str ,
643- Utf8View ,
644- StringViewArray
642+ Utf8 ,
643+ StringArray
645644 ) ;
646645 test_function ! (
647646 SubstrFunc :: new( ) ,
@@ -651,8 +650,8 @@ mod tests {
651650 ] ,
652651 Ok ( Some ( "alphabet" ) ) ,
653652 & str ,
654- Utf8View ,
655- StringViewArray
653+ Utf8 ,
654+ StringArray
656655 ) ;
657656 test_function ! (
658657 SubstrFunc :: new( ) ,
@@ -662,8 +661,8 @@ mod tests {
662661 ] ,
663662 Ok ( Some ( "lphabet" ) ) ,
664663 & str ,
665- Utf8View ,
666- StringViewArray
664+ Utf8 ,
665+ StringArray
667666 ) ;
668667 test_function ! (
669668 SubstrFunc :: new( ) ,
@@ -673,8 +672,8 @@ mod tests {
673672 ] ,
674673 Ok ( Some ( "phabet" ) ) ,
675674 & str ,
676- Utf8View ,
677- StringViewArray
675+ Utf8 ,
676+ StringArray
678677 ) ;
679678 test_function ! (
680679 SubstrFunc :: new( ) ,
@@ -684,8 +683,8 @@ mod tests {
684683 ] ,
685684 Ok ( Some ( "alphabet" ) ) ,
686685 & str ,
687- Utf8View ,
688- StringViewArray
686+ Utf8 ,
687+ StringArray
689688 ) ;
690689 test_function ! (
691690 SubstrFunc :: new( ) ,
@@ -695,8 +694,8 @@ mod tests {
695694 ] ,
696695 Ok ( Some ( "" ) ) ,
697696 & str ,
698- Utf8View ,
699- StringViewArray
697+ Utf8 ,
698+ StringArray
700699 ) ;
701700 test_function ! (
702701 SubstrFunc :: new( ) ,
@@ -706,8 +705,8 @@ mod tests {
706705 ] ,
707706 Ok ( None ) ,
708707 & str ,
709- Utf8View ,
710- StringViewArray
708+ Utf8 ,
709+ StringArray
711710 ) ;
712711 test_function ! (
713712 SubstrFunc :: new( ) ,
@@ -718,8 +717,8 @@ mod tests {
718717 ] ,
719718 Ok ( Some ( "ph" ) ) ,
720719 & str ,
721- Utf8View ,
722- StringViewArray
720+ Utf8 ,
721+ StringArray
723722 ) ;
724723 test_function ! (
725724 SubstrFunc :: new( ) ,
@@ -730,8 +729,8 @@ mod tests {
730729 ] ,
731730 Ok ( Some ( "phabet" ) ) ,
732731 & str ,
733- Utf8View ,
734- StringViewArray
732+ Utf8 ,
733+ StringArray
735734 ) ;
736735 test_function ! (
737736 SubstrFunc :: new( ) ,
@@ -742,8 +741,8 @@ mod tests {
742741 ] ,
743742 Ok ( Some ( "alph" ) ) ,
744743 & str ,
745- Utf8View ,
746- StringViewArray
744+ Utf8 ,
745+ StringArray
747746 ) ;
748747 // starting from 5 (10 + -5)
749748 test_function ! (
@@ -755,8 +754,8 @@ mod tests {
755754 ] ,
756755 Ok ( Some ( "alph" ) ) ,
757756 & str ,
758- Utf8View ,
759- StringViewArray
757+ Utf8 ,
758+ StringArray
760759 ) ;
761760 // starting from -1 (4 + -5)
762761 test_function ! (
@@ -768,8 +767,8 @@ mod tests {
768767 ] ,
769768 Ok ( Some ( "" ) ) ,
770769 & str ,
771- Utf8View ,
772- StringViewArray
770+ Utf8 ,
771+ StringArray
773772 ) ;
774773 // starting from 0 (5 + -5)
775774 test_function ! (
@@ -781,8 +780,8 @@ mod tests {
781780 ] ,
782781 Ok ( Some ( "" ) ) ,
783782 & str ,
784- Utf8View ,
785- StringViewArray
783+ Utf8 ,
784+ StringArray
786785 ) ;
787786 test_function ! (
788787 SubstrFunc :: new( ) ,
@@ -793,8 +792,8 @@ mod tests {
793792 ] ,
794793 Ok ( None ) ,
795794 & str ,
796- Utf8View ,
797- StringViewArray
795+ Utf8 ,
796+ StringArray
798797 ) ;
799798 test_function ! (
800799 SubstrFunc :: new( ) ,
@@ -805,8 +804,8 @@ mod tests {
805804 ] ,
806805 Ok ( None ) ,
807806 & str ,
808- Utf8View ,
809- StringViewArray
807+ Utf8 ,
808+ StringArray
810809 ) ;
811810 test_function ! (
812811 SubstrFunc :: new( ) ,
@@ -817,8 +816,8 @@ mod tests {
817816 ] ,
818817 exec_err!( "negative substring length not allowed: substr(<str>, 1, -1)" ) ,
819818 & str ,
820- Utf8View ,
821- StringViewArray
819+ Utf8 ,
820+ StringArray
822821 ) ;
823822 test_function ! (
824823 SubstrFunc :: new( ) ,
@@ -829,8 +828,8 @@ mod tests {
829828 ] ,
830829 Ok ( Some ( "és" ) ) ,
831830 & str ,
832- Utf8View ,
833- StringViewArray
831+ Utf8 ,
832+ StringArray
834833 ) ;
835834 #[ cfg( not( feature = "unicode_expressions" ) ) ]
836835 test_function ! (
@@ -843,8 +842,8 @@ mod tests {
843842 "function substr requires compilation with feature flag: unicode_expressions."
844843 ) ,
845844 & str ,
846- Utf8View ,
847- StringViewArray
845+ Utf8 ,
846+ StringArray
848847 ) ;
849848 test_function ! (
850849 SubstrFunc :: new( ) ,
@@ -854,8 +853,8 @@ mod tests {
854853 ] ,
855854 Ok ( Some ( "abc" ) ) ,
856855 & str ,
857- Utf8View ,
858- StringViewArray
856+ Utf8 ,
857+ StringArray
859858 ) ;
860859 test_function ! (
861860 SubstrFunc :: new( ) ,
@@ -866,8 +865,8 @@ mod tests {
866865 ] ,
867866 exec_err!( "negative overflow when calculating skip value" ) ,
868867 & str ,
869- Utf8View ,
870- StringViewArray
868+ Utf8 ,
869+ StringArray
871870 ) ;
872871
873872 Ok ( ( ) )
0 commit comments