Skip to content

Commit 07e23bd

Browse files
committed
more tests including type mismatch, operators and NULL
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
1 parent 2e12048 commit 07e23bd

File tree

6 files changed

+122
-22
lines changed

6 files changed

+122
-22
lines changed

datafusion/core/tests/set_comparison.rs

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow::array::Int32Array;
20+
use arrow::array::{Int32Array, StringArray};
2121
use arrow::datatypes::{DataType, Field, Schema};
2222
use arrow::record_batch::RecordBatch;
2323
use datafusion::prelude::SessionContext;
24-
use datafusion_common::{Result, assert_batches_eq};
24+
use datafusion_common::{Result, assert_batches_eq, assert_contains};
2525

2626
fn build_table(values: &[i32]) -> Result<RecordBatch> {
2727
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
@@ -82,3 +82,94 @@ async fn set_comparison_all_empty() -> Result<()> {
8282
);
8383
Ok(())
8484
}
85+
86+
#[tokio::test]
87+
async fn set_comparison_type_mismatch() -> Result<()> {
88+
let ctx = SessionContext::new();
89+
90+
ctx.register_batch("t", build_table(&[1])?)?;
91+
ctx.register_batch("strings", {
92+
let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
93+
let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")]))
94+
as Arc<dyn arrow::array::Array>;
95+
RecordBatch::try_new(schema, vec![array])?
96+
})?;
97+
98+
let df = ctx
99+
.sql("select v from t where v > any(select s from strings)")
100+
.await?;
101+
let err = df.collect().await.unwrap_err();
102+
assert_contains!(
103+
err.to_string(),
104+
"expr type Int32 can't cast to Utf8 in SetComparison"
105+
);
106+
Ok(())
107+
}
108+
109+
#[tokio::test]
110+
async fn set_comparison_multiple_operators() -> Result<()> {
111+
let ctx = SessionContext::new();
112+
113+
ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?;
114+
ctx.register_batch("s", build_table(&[2, 3])?)?;
115+
116+
let df = ctx
117+
.sql("select v from t where v = any(select v from s) order by v")
118+
.await?;
119+
let results = df.collect().await?;
120+
assert_batches_eq!(
121+
&["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",],
122+
&results
123+
);
124+
125+
let df = ctx
126+
.sql("select v from t where v != all(select v from s) order by v")
127+
.await?;
128+
let results = df.collect().await?;
129+
assert_batches_eq!(
130+
&["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",],
131+
&results
132+
);
133+
134+
let df = ctx
135+
.sql("select v from t where v >= all(select v from s) order by v")
136+
.await?;
137+
let results = df.collect().await?;
138+
assert_batches_eq!(
139+
&["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",],
140+
&results
141+
);
142+
143+
let df = ctx
144+
.sql("select v from t where v <= any(select v from s) order by v")
145+
.await?;
146+
let results = df.collect().await?;
147+
assert_batches_eq!(
148+
&[
149+
"+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+",
150+
],
151+
&results
152+
);
153+
Ok(())
154+
}
155+
156+
#[tokio::test]
157+
async fn set_comparison_null_semantics_all() -> Result<()> {
158+
let ctx = SessionContext::new();
159+
160+
ctx.register_batch("t", build_table(&[5])?)?;
161+
ctx.register_batch("s", {
162+
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
163+
let array = Arc::new(Int32Array::from(vec![Some(1), None]))
164+
as Arc<dyn arrow::array::Array>;
165+
RecordBatch::try_new(schema, vec![array])?
166+
})?;
167+
168+
let df = ctx
169+
.sql("select v from t where v != all(select v from s)")
170+
.await?;
171+
let results = df.collect().await?;
172+
let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
173+
assert_eq!(0, row_count);
174+
Ok(())
175+
}

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
381381
.data;
382382
let expr_type = expr.get_type(self.schema)?;
383383
let subquery_type = new_plan.schema().field(0).data_type();
384+
if (expr_type.is_numeric()
385+
&& is_utf8_or_utf8view_or_large_utf8(subquery_type))
386+
|| (subquery_type.is_numeric()
387+
&& is_utf8_or_utf8view_or_large_utf8(&expr_type))
388+
{
389+
return plan_err!(
390+
"expr type {expr_type} can't cast to {subquery_type} in SetComparison"
391+
);
392+
}
384393
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
385394
plan_datafusion_err!(
386395
"expr type {expr_type} can't cast to {subquery_type} in SetComparison"

datafusion/sql/src/unparser/expr.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,27 @@ use sqlparser::ast::{
2525
use std::sync::Arc;
2626
use std::vec;
2727

28-
use super::dialect::IntervalStyle;
2928
use super::Unparser;
29+
use super::dialect::IntervalStyle;
3030
use arrow::array::{
31+
ArrayRef, Date32Array, Date64Array, PrimitiveArray,
3132
types::{
3233
ArrowTemporalType, Time32MillisecondType, Time32SecondType,
3334
Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
3435
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
3536
},
36-
ArrayRef, Date32Array, Date64Array, PrimitiveArray,
3737
};
3838
use arrow::datatypes::{
39-
DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType,
39+
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
4040
};
4141
use arrow::util::display::array_value_to_string;
4242
use datafusion_common::{
43-
assert_eq_or_internal_err, assert_or_internal_err, internal_datafusion_err,
44-
internal_err, not_impl_err, plan_err, Column, Result, ScalarValue,
43+
Column, Result, ScalarValue, assert_eq_or_internal_err, assert_or_internal_err,
44+
internal_datafusion_err, internal_err, not_impl_err, plan_err,
4545
};
4646
use datafusion_expr::{
47-
expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction},
4847
Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast,
48+
expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction},
4949
};
5050
use sqlparser::ast::helpers::attached_token::AttachedToken;
5151
use sqlparser::tokenizer::Span;
@@ -1831,12 +1831,12 @@ mod tests {
18311831
use datafusion_common::{Spans, TableReference};
18321832
use datafusion_expr::expr::WildcardOptions;
18331833
use datafusion_expr::{
1834-
case, cast, col, cube, exists, grouping_set, interval_datetime_lit,
1835-
interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup,
1836-
table_scan, try_cast, when, ColumnarValue, ScalarFunctionArgs, ScalarUDF,
1837-
ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition,
1834+
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1835+
Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, cube, exists,
1836+
grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not,
1837+
not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when,
18381838
};
1839-
use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt};
1839+
use datafusion_expr::{ExprFunctionExt, interval_month_day_nano_lit};
18401840
use datafusion_functions::datetime::from_unixtime::FromUnixtimeFunc;
18411841
use datafusion_functions::expr_fn::{get_field, named_struct};
18421842
use datafusion_functions_aggregate::count::count_udaf;

datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
// under the License.
1717

1818
use crate::logical_plan::consumer::SubstraitConsumer;
19-
use datafusion::common::{substrait_datafusion_err, substrait_err, DFSchema, Spans};
19+
use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err};
2020
use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier};
2121
use datafusion::logical_expr::{Expr, Operator, Subquery};
2222
use std::sync::Arc;
2323
use substrait::proto::expression as substrait_expression;
24+
use substrait::proto::expression::subquery::SubqueryType;
2425
use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp};
2526
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
26-
use substrait::proto::expression::subquery::SubqueryType;
2727

2828
pub async fn from_subquery(
2929
consumer: &impl SubstraitConsumer,

datafusion/substrait/src/logical_plan/producer/substrait_producer.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ use crate::logical_plan::producer::{
2525
from_union, from_values, from_window, from_window_function, to_substrait_rel,
2626
to_substrait_rex,
2727
};
28-
use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue};
29-
use datafusion::execution::registry::SerializerRegistry;
28+
use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err};
3029
use datafusion::execution::SessionState;
30+
use datafusion::execution::registry::SerializerRegistry;
3131
use datafusion::logical_expr::expr::{
3232
Alias, InList, InSubquery, SetComparison, WindowFunction,
3333
};
3434
use datafusion::logical_expr::{
35-
expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr,
36-
Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort,
37-
SubqueryAlias, TableScan, TryCast, Union, Values, Window,
35+
Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension,
36+
Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, SubqueryAlias,
37+
TableScan, TryCast, Union, Values, Window, expr,
3838
};
3939
use pbjson_types::Any as ProtoAny;
4040
use substrait::proto::aggregate_rel::Measure;

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use std::mem::size_of_val;
2929

3030
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
3131
use datafusion::common::tree_node::Transformed;
32-
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef, Spans};
32+
use datafusion::common::{DFSchema, DFSchemaRef, Spans, not_impl_err, plan_err};
3333
use datafusion::error::Result;
3434
use datafusion::execution::registry::SerializerRegistry;
3535
use datafusion::execution::runtime_env::RuntimeEnv;
@@ -46,7 +46,7 @@ use std::hash::Hash;
4646
use std::sync::Arc;
4747
use substrait::proto::extensions::simple_extension_declaration::MappingType;
4848
use substrait::proto::rel::RelType;
49-
use substrait::proto::{plan_rel, Plan, Rel};
49+
use substrait::proto::{Plan, Rel, plan_rel};
5050

5151
#[derive(Debug)]
5252
struct MockSerializerRegistry;

0 commit comments

Comments
 (0)