diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index c128ee0e6f58d..e8cb6f3b484ef 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -91,6 +91,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::Cast { .. } | Expr::TryCast { .. } | Expr::BinaryExpr { .. } + | Expr::AnyExpr { .. } | Expr::Between { .. } | Expr::InList { .. } | Expr::GetIndexedField { .. } diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs index d6cf4c08ba94e..99cb539fe19a2 100644 --- a/datafusion/core/src/logical_plan/expr_rewriter.rs +++ b/datafusion/core/src/logical_plan/expr_rewriter.rs @@ -119,6 +119,11 @@ impl ExprRewritable for Expr { op, right: rewrite_boxed(right, rewriter)?, }, + Expr::AnyExpr { left, op, right } => Expr::AnyExpr { + left: rewrite_boxed(left, rewriter)?, + op, + right: rewrite_boxed(right, rewriter)?, + }, Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), diff --git a/datafusion/core/src/logical_plan/expr_schema.rs b/datafusion/core/src/logical_plan/expr_schema.rs index 19ac56fe0f92e..1a8adbf042644 100644 --- a/datafusion/core/src/logical_plan/expr_schema.rs +++ b/datafusion/core/src/logical_plan/expr_schema.rs @@ -111,6 +111,7 @@ impl ExprSchemable for Expr { | Expr::IsNull(_) | Expr::Between { .. } | Expr::InList { .. } + | Expr::AnyExpr { .. } | Expr::IsNotNull(_) => Ok(DataType::Boolean), Expr::BinaryExpr { ref left, @@ -189,6 +190,11 @@ impl ExprSchemable for Expr { ref right, .. } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), + Expr::AnyExpr { + ref left, + ref right, + .. + } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs index 9a771fcfa35bf..ea17ad585f00d 100644 --- a/datafusion/core/src/logical_plan/expr_visitor.rs +++ b/datafusion/core/src/logical_plan/expr_visitor.rs @@ -116,6 +116,10 @@ impl ExprVisitable for Expr { let visitor = left.accept(visitor)?; right.accept(visitor) } + Expr::AnyExpr { left, right, .. } => { + let visitor = left.accept(visitor)?; + right.accept(visitor) + } Expr::Between { expr, low, high, .. } => { diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs index e19d525b1645e..f9feb02c2bc11 100644 --- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs @@ -423,6 +423,10 @@ impl ExprIdentifierVisitor<'_> { desc.push_str("BinaryExpr-"); desc.push_str(&op.to_string()); } + Expr::AnyExpr { op, .. } => { + desc.push_str("AnyExpr-"); + desc.push_str(&op.to_string()); + } Expr::Not(_) => { desc.push_str("Not-"); } diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs index 61a6fd33fa026..e4ae2c35cb5cb 100644 --- a/datafusion/core/src/optimizer/simplify_expressions.rs +++ b/datafusion/core/src/optimizer/simplify_expressions.rs @@ -385,6 +385,7 @@ impl<'a> ConstEvaluator<'a> { Expr::TableUDF { .. } => false, Expr::Literal(_) | Expr::BinaryExpr { .. } + | Expr::AnyExpr { .. } | Expr::Not(_) | Expr::IsNotNull(_) | Expr::IsNull(_) diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 1748f497d5403..504ebaa1bb100 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -73,6 +73,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { Expr::Alias(_, _) | Expr::Literal(_) | Expr::BinaryExpr { .. } + | Expr::AnyExpr { .. } | Expr::Not(_) | Expr::IsNotNull(_) | Expr::IsNull(_) @@ -305,6 +306,9 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::BinaryExpr { left, right, .. } => { Ok(vec![left.as_ref().to_owned(), right.as_ref().to_owned()]) } + Expr::AnyExpr { left, right, .. } => { + Ok(vec![left.as_ref().to_owned(), right.as_ref().to_owned()]) + } Expr::IsNull(expr) | Expr::IsNotNull(expr) | Expr::Cast { expr, .. } @@ -394,6 +398,11 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { op: *op, right: Box::new(expressions[1].clone()), }), + Expr::AnyExpr { op, .. } => Ok(Expr::AnyExpr { + left: Box::new(expressions[0].clone()), + op: *op, + right: Box::new(expressions[1].clone()), + }), Expr::IsNull(_) => Ok(Expr::IsNull(Box::new(expressions[0].clone()))), Expr::IsNotNull(_) => Ok(Expr::IsNotNull(Box::new(expressions[0].clone()))), Expr::ScalarFunction { fun, .. } => Ok(Expr::ScalarFunction { diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 910339714c490..c66c289e084b4 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -67,7 +67,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::{compute::can_cast_types, datatypes::DataType}; use async_trait::async_trait; use datafusion_common::OuterQueryCursor; -use datafusion_physical_expr::expressions::OuterColumn; +use datafusion_physical_expr::expressions::{any, OuterColumn}; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -112,6 +112,11 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let right = create_physical_name(right, false)?; Ok(format!("{} {:?} {}", left, op, right)) } + Expr::AnyExpr { left, op, right } => { + let left = create_physical_name(left, false)?; + let right = create_physical_name(right, false)?; + Ok(format!("{} {:?} ANY({})", left, op, right)) + } Expr::Case { expr, when_then_expr, @@ -1096,7 +1101,6 @@ pub fn create_physical_expr( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, create_physical_expr(key, input_dfschema, input_schema, execution_props)?, ))), - Expr::ScalarFunction { fun, args } => { let physical_args = args .iter() @@ -1172,6 +1176,21 @@ pub fn create_physical_expr( binary_expr } } + Expr::AnyExpr { left, op, right } => { + let lhs = create_physical_expr( + left, + input_dfschema, + input_schema, + execution_props, + )?; + let rhs = create_physical_expr( + right, + input_dfschema, + input_schema, + execution_props, + )?; + any(lhs, *op, rhs, input_schema) + } Expr::InList { expr, list, diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 96df9d2791de6..94b50234647b5 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -1477,6 +1477,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + fn parse_sql_binary_any( + &self, + left: SQLExpr, + op: BinaryOperator, + right: Box, + schema: &DFSchema, + ) -> Result { + let operator = match op { + BinaryOperator::Eq => Ok(Operator::Eq), + BinaryOperator::NotEq => Ok(Operator::NotEq), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported SQL ANY operator {:?}", + op + ))), + }?; + + Ok(Expr::AnyExpr { + left: Box::new(self.sql_expr_to_logical_expr(left, schema)?), + op: operator, + right: Box::new(self.sql_expr_to_logical_expr(*right, schema)?), + }) + } + fn parse_sql_binary_op( &self, left: SQLExpr, @@ -1484,6 +1507,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { right: SQLExpr, schema: &DFSchema, ) -> Result { + match right { + SQLExpr::AnyOp(any_expr) => { + return self.parse_sql_binary_any(left, op, any_expr, schema); + } + SQLExpr::AllOp(_) => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported SQL ALL operator {:?}", + right + ))); + } + _ => {} + }; + let operator = match op { BinaryOperator::Gt => Ok(Operator::Gt), BinaryOperator::GtEq => Ok(Operator::GtEq), diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index 2614664552e74..50fa085dd62c9 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -297,6 +297,11 @@ where op: *op, right: Box::new(clone_with_replacement(&**right, replacement_fn)?), }), + Expr::AnyExpr { left, right, op } => Ok(Expr::AnyExpr { + left: Box::new(clone_with_replacement(&**left, replacement_fn)?), + op: *op, + right: Box::new(clone_with_replacement(&**right, replacement_fn)?), + }), Expr::Case { expr: case_expr_opt, when_then_expr, diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 03e6a4527c3e9..a8edb47f28911 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -859,6 +859,34 @@ async fn test_extract_date_part() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_binary_any() -> Result<()> { + // = + // int64 + test_expression!("1 = ANY([1, 2])", "true"); + test_expression!("3 = ANY([1, 2])", "false"); + test_expression!("NULL = ANY([1, 2])", "NULL"); + // float + test_expression!("1.0 = ANY([1.0, 2.0])", "true"); + test_expression!("3.0 = ANY([1.0, 2.0])", "false"); + // utf8 + test_expression!("'a' = ANY(['a', 'b'])", "true"); + test_expression!("'c' = ANY(['a', 'b'])", "false"); + // bool + test_expression!("true = ANY([true, false])", "true"); + test_expression!("false = ANY([true, false])", "true"); + test_expression!("false = ANY([true, true])", "false"); + // <> + test_expression!("3 <> ANY([1, 2])", "true"); + test_expression!("1 <> ANY([1, 2])", "false"); + test_expression!("2 <> ANY([1, 2])", "false"); + test_expression!("NULL = ANY([1, 2])", "NULL"); + test_expression!("'c' <> ANY(['a', 'b'])", "true"); + test_expression!("'a' <> ANY(['a', 'b'])", "false"); + + Ok(()) +} + #[tokio::test] async fn test_in_list_scalar() -> Result<()> { test_expression!("'a' IN ('a','b')", "true"); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 5e624d005ed79..b501020119858 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -101,6 +101,15 @@ pub enum Expr { /// Right-hand side of the expression right: Box, }, + /// A binary expression such as "age > 21" + AnyExpr { + /// Left-hand side of the expression + left: Box, + /// The comparison operator + op: Operator, + /// Right-hand side of the expression + right: Box, + }, /// Negation of an expression. The expression's type must be a boolean to make sense. Not(Box), /// Whether an expression is not Null. This expression is never null. @@ -445,6 +454,9 @@ impl fmt::Debug for Expr { Expr::BinaryExpr { left, op, right } => { write!(f, "{:?} {} {:?}", left, op, right) } + Expr::AnyExpr { left, op, right } => { + write!(f, "{:?} {} ANY({:?})", left, op, right) + } Expr::Sort { expr, asc, @@ -587,6 +599,11 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { let right = create_name(right, input_schema)?; Ok(format!("{} {} {}", left, op, right)) } + Expr::AnyExpr { left, op, right } => { + let left = create_name(left, input_schema)?; + let right = create_name(right, input_schema)?; + Ok(format!("{} {} ANY({})", left, op, right)) + } Expr::Case { expr, when_then_expr, diff --git a/datafusion/physical-expr/src/expressions/any.rs b/datafusion/physical-expr/src/expressions/any.rs new file mode 100644 index 0000000000000..d3be7ee80ffe0 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/any.rs @@ -0,0 +1,612 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Any expression + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, + PrimitiveArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::datatypes::ArrowPrimitiveType; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; + +use crate::expressions::try_cast; +use crate::PhysicalExpr; +use arrow::array::*; + +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{ColumnarValue, Operator}; + +macro_rules! compare_op_scalar { + ($LEFT: expr, $LIST_VALUES:expr, $OP:expr, $LIST_VALUES_TYPE:ty, $LIST_FROM_SCALAR: expr) => {{ + let mut builder = BooleanBuilder::new($LEFT.len()); + + if $LIST_FROM_SCALAR { + for i in 0..$LEFT.len() { + if $LEFT.is_null(i) { + builder.append_null()?; + } else { + if $LIST_VALUES.is_null(0) { + builder.append_null()?; + } else { + builder.append_value($OP( + $LEFT.value(i), + $LIST_VALUES + .value(0) + .as_any() + .downcast_ref::<$LIST_VALUES_TYPE>() + .unwrap(), + ))?; + } + } + } + } else { + for i in 0..$LEFT.len() { + if $LEFT.is_null(i) { + builder.append_null()?; + } else { + if $LIST_VALUES.is_null(i) { + builder.append_null()?; + } else { + builder.append_value($OP( + $LEFT.value(i), + $LIST_VALUES + .value(i) + .as_any() + .downcast_ref::<$LIST_VALUES_TYPE>() + .unwrap(), + ))?; + } + } + } + } + + Ok(builder.finish()) + }}; +} + +macro_rules! make_primitive { + ($VALUES:expr, $IN_VALUES:expr, $NEGATED:expr, $TYPE:ident, $LIST_FROM_SCALAR: expr) => {{ + let left = $VALUES.as_any().downcast_ref::<$TYPE>().expect(&format!( + "Unable to downcast values to {}", + stringify!($TYPE) + )); + + if $NEGATED { + Ok(ColumnarValue::Array(Arc::new(neq_primitive( + left, + $IN_VALUES, + $LIST_FROM_SCALAR, + )?))) + } else { + Ok(ColumnarValue::Array(Arc::new(eq_primitive( + left, + $IN_VALUES, + $LIST_FROM_SCALAR, + )?))) + } + }}; +} + +fn eq_primitive( + array: &PrimitiveArray, + list: &ListArray, + list_from_scalar: bool, +) -> Result { + compare_op_scalar!( + array, + list, + |x, v: &PrimitiveArray| v.values().contains(&x), + PrimitiveArray, + list_from_scalar + ) +} + +fn neq_primitive( + array: &PrimitiveArray, + list: &ListArray, + list_from_scalar: bool, +) -> Result { + compare_op_scalar!( + array, + list, + |x, v: &PrimitiveArray| !v.values().contains(&x), + PrimitiveArray, + list_from_scalar + ) +} + +fn eq_bool( + array: &BooleanArray, + list: &ListArray, + list_from_scalar: bool, +) -> Result { + compare_op_scalar!( + array, + list, + |x, v: &BooleanArray| unsafe { + for i in 0..v.len() { + if v.value_unchecked(i) == x { + return true; + } + } + + return false; + }, + BooleanArray, + list_from_scalar + ) +} + +fn neq_bool( + array: &BooleanArray, + list: &ListArray, + list_from_scalar: bool, +) -> Result { + compare_op_scalar!( + array, + list, + |x, v: &BooleanArray| unsafe { + for i in 0..v.len() { + if v.value_unchecked(i) == x { + return false; + } + } + + return true; + }, + BooleanArray, + list_from_scalar + ) +} + +fn eq_utf8( + array: &GenericStringArray, + list: &ListArray, + list_from_scalar: bool, +) -> Result { + compare_op_scalar!( + array, + list, + |x, v: &GenericStringArray| unsafe { + for i in 0..v.len() { + if v.value_unchecked(i) == x { + return true; + } + } + + return false; + }, + GenericStringArray, + list_from_scalar + ) +} + +fn neq_utf8( + array: &GenericStringArray, + list: &ListArray, + list_from_scalar: bool, +) -> Result { + compare_op_scalar!( + array, + list, + |x, v: &GenericStringArray| unsafe { + for i in 0..v.len() { + if v.value_unchecked(i) == x { + return false; + } + } + + return true; + }, + GenericStringArray, + list_from_scalar + ) +} + +/// AnyExpr +#[derive(Debug)] +pub struct AnyExpr { + value: Arc, + op: Operator, + list: Arc, +} + +impl AnyExpr { + /// Create a new InList expression + pub fn new( + value: Arc, + op: Operator, + list: Arc, + ) -> Self { + Self { value, op, list } + } + + /// Compare for specific utf8 types + fn compare_utf8( + &self, + array: ArrayRef, + list: &ListArray, + negated: bool, + list_from_scalar: bool, + ) -> Result { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + if negated { + Ok(ColumnarValue::Array(Arc::new(neq_utf8( + array, + &list, + list_from_scalar, + )?))) + } else { + Ok(ColumnarValue::Array(Arc::new(eq_utf8( + array, + &list, + list_from_scalar, + )?))) + } + } + + /// Compare for specific utf8 types + fn compare_bool( + &self, + array: ArrayRef, + list: &ListArray, + negated: bool, + list_from_scalar: bool, + ) -> Result { + let array = array.as_any().downcast_ref::().unwrap(); + + if negated { + Ok(ColumnarValue::Array(Arc::new(neq_bool( + array, + &list, + list_from_scalar, + )?))) + } else { + Ok(ColumnarValue::Array(Arc::new(eq_bool( + array, + &list, + list_from_scalar, + )?))) + } + } +} + +impl std::fmt::Display for AnyExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} {} ANY({})", self.value, self.op, self.list) + } +} + +impl PhysicalExpr for AnyExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.value.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let value = match self.value.evaluate(batch)? { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array(), + }; + + let (list, list_from_scalar) = match self.list.evaluate(batch)? { + ColumnarValue::Array(array) => (array, false), + ColumnarValue::Scalar(scalar) => (scalar.to_array(), true), + }; + let as_list = list + .as_any() + .downcast_ref::() + .expect("Unable to downcast list to ListArray"); + + let negated = match self.op { + Operator::Eq => false, + Operator::NotEq => true, + op => { + return Err(DataFusionError::NotImplemented(format!( + "Operator for ANY expression, actual: {:?}", + op + ))); + } + }; + + match value.data_type() { + DataType::Float16 => { + make_primitive!(value, as_list, negated, Float16Array, list_from_scalar) + } + DataType::Float32 => { + make_primitive!(value, as_list, negated, Float32Array, list_from_scalar) + } + DataType::Float64 => { + make_primitive!(value, as_list, negated, Float64Array, list_from_scalar) + } + DataType::Int8 => { + make_primitive!(value, as_list, negated, Int8Array, list_from_scalar) + } + DataType::Int16 => { + make_primitive!(value, as_list, negated, Int16Array, list_from_scalar) + } + DataType::Int32 => { + make_primitive!(value, as_list, negated, Int32Array, list_from_scalar) + } + DataType::Int64 => { + make_primitive!(value, as_list, negated, Int64Array, list_from_scalar) + } + DataType::UInt8 => { + make_primitive!(value, as_list, negated, UInt8Array, list_from_scalar) + } + DataType::UInt16 => { + make_primitive!(value, as_list, negated, UInt16Array, list_from_scalar) + } + DataType::UInt32 => { + make_primitive!(value, as_list, negated, UInt32Array, list_from_scalar) + } + DataType::UInt64 => { + make_primitive!(value, as_list, negated, UInt64Array, list_from_scalar) + } + DataType::Boolean => { + self.compare_bool(value, as_list, negated, list_from_scalar) + } + DataType::Utf8 => { + self.compare_utf8::(value, as_list, negated, list_from_scalar) + } + DataType::LargeUtf8 => { + self.compare_utf8::(value, as_list, negated, list_from_scalar) + } + datatype => Result::Err(DataFusionError::NotImplemented(format!( + "AnyExpr does not support datatype {:?}.", + datatype + ))), + } + } +} + +/// return two physical expressions that are optionally coerced to a +/// common type that the binary operator supports. +fn any_cast( + value: Arc, + _op: &Operator, + list: Arc, + input_schema: &Schema, +) -> Result<(Arc, Arc)> { + let tmp = list.data_type(input_schema)?; + let list_type = match &tmp { + DataType::List(f) => f.data_type(), + _ => panic!("wtf"), + }; + + Ok((try_cast(value, input_schema, list_type.clone())?, list)) +} + +/// Creates an expression AnyExpr +pub fn any( + value: Arc, + op: Operator, + list: Arc, + input_schema: &Schema, +) -> Result> { + let (l, r) = any_cast(value, &op, list, input_schema)?; + Ok(Arc::new(AnyExpr::new(l, op, r))) +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Field; + + use super::*; + use crate::expressions::{col, lit}; + use datafusion_common::{Result, ScalarValue}; + + // applies the any expr to an input batch + macro_rules! execute_any { + ($BATCH:expr, $OP:expr, $EXPECTED:expr, $COL_A:expr, $COL_B:expr, $SCHEMA:expr) => {{ + let expr = any($COL_A, $OP, $COL_B, $SCHEMA).unwrap(); + let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + let expected = &BooleanArray::from($EXPECTED); + assert_eq!(expected, result); + }}; + } + + #[test] + fn any_int64_array_list() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, true); + let field_b = Field::new( + "b", + DataType::List(Box::new(Field::new("item", DataType::Int64, true))), + true, + ); + + let schema = Schema::new(vec![field_a.clone(), field_b.clone()]); + let a = Int64Array::from(vec![Some(0), Some(3), None]); + let col_a = col("a", &schema)?; + + let values_builder = Int64Builder::new(3 * 3); + let mut builder = ListBuilder::new(values_builder); + + for _ in 0..3 { + builder.values().append_value(0).unwrap(); + builder.values().append_value(1).unwrap(); + builder.values().append_value(2).unwrap(); + builder.append(true).unwrap(); + } + + let b = builder.finish(); + let col_b = col("b", &schema)?; + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b)], + )?; + + execute_any!( + batch, + Operator::Eq, + vec![Some(true), Some(false), None], + col_a.clone(), + col_b.clone(), + &schema + ); + + Ok(()) + } + + // applies the any expr to an input batch and list + macro_rules! execute_any_with_list { + ($BATCH:expr, $LIST:expr, $OP:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ + let expr = any($COL, $OP, $LIST, $SCHEMA).unwrap(); + let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + let expected = &BooleanArray::from($EXPECTED); + assert_eq!(expected, result); + }}; + } + + #[test] + fn any_int64_scalar_list() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, true); + let schema = Schema::new(vec![field_a.clone()]); + let a = Int64Array::from(vec![Some(0), Some(3), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a = ANY (0, 1, 2)" + let list = lit(ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::Int64(Some(0)), + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)), + ])), + Box::new(DataType::Int64), + )); + + let schema = &Schema::new(vec![ + field_a, + Field::new( + "b", + DataType::List(Box::new(Field::new("d", DataType::Int64, true))), + true, + ), + ]); + execute_any_with_list!( + batch, + list, + Operator::Eq, + vec![Some(true), Some(false), None], + col_a.clone(), + schema + ); + + Ok(()) + } + + #[test] + fn any_utf8_scalar_list() -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, true); + let schema = Schema::new(vec![field_a.clone()]); + let a = StringArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a = ANY ('a', 'b', 'c')" + let list = lit(ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::Utf8(Some("a".to_string())), + ScalarValue::Utf8(Some("b".to_string())), + ScalarValue::Utf8(Some("c".to_string())), + ])), + Box::new(DataType::Utf8), + )); + + let schema = &Schema::new(vec![ + field_a, + Field::new( + "b", + DataType::List(Box::new(Field::new("d", DataType::Utf8, true))), + true, + ), + ]); + execute_any_with_list!( + batch, + list, + Operator::Eq, + vec![Some(true), Some(false), None], + col_a.clone(), + schema + ); + + Ok(()) + } + + #[test] + fn any_bool_scalar_list() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, true); + let schema = Schema::new(vec![field_a.clone()]); + let a = BooleanArray::from(vec![Some(true), Some(false), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + // expression: "a = ANY (true)" + let list = lit(ScalarValue::List( + Some(Box::new(vec![ScalarValue::Boolean(Some(true))])), + Box::new(DataType::Boolean), + )); + + let schema = &Schema::new(vec![ + field_a, + Field::new( + "b", + DataType::List(Box::new(Field::new("d", DataType::Boolean, true))), + true, + ), + ]); + execute_any_with_list!( + batch, + list, + Operator::Eq, + vec![Some(true), Some(false), None], + col_a.clone(), + schema + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 4a860621e17cb..7b13490631c3c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -37,6 +37,7 @@ mod lead_lag; mod literal; #[macro_use] mod min_max; +mod any; mod approx_median; mod correlation; mod covariance; @@ -83,6 +84,7 @@ pub use covariance::{ }; pub use cume_dist::cume_dist; +pub use any::any; pub use distinct_expressions::{DistinctArrayAgg, DistinctCount}; pub use get_indexed_field::GetIndexedFieldExpr; pub use in_list::{in_list, InListExpr};