Skip to content

Commit 5d12582

Browse files
committed
substrait support
Signed-off-by: Ruihang Xia <[email protected]>
1 parent c1faa88 commit 5d12582

File tree

5 files changed

+198
-17
lines changed

5 files changed

+198
-17
lines changed

datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
// under the License.
1717

1818
use crate::logical_plan::consumer::SubstraitConsumer;
19-
use datafusion::common::{substrait_err, DFSchema, Spans};
20-
use datafusion::logical_expr::expr::{Exists, InSubquery};
21-
use datafusion::logical_expr::{Expr, Subquery};
19+
use datafusion::common::{substrait_datafusion_err, substrait_err, DFSchema, Spans};
20+
use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier};
21+
use datafusion::logical_expr::{Expr, Operator, Subquery};
2222
use std::sync::Arc;
2323
use substrait::proto::expression as substrait_expression;
24+
use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp};
2425
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
2526
use substrait::proto::expression::subquery::SubqueryType;
2627

@@ -94,8 +95,53 @@ pub async fn from_subquery(
9495
),
9596
}
9697
}
97-
other_type => {
98-
substrait_err!("Subquery type {other_type:?} not implemented")
98+
SubqueryType::SetComparison(comparison) => {
99+
let left = comparison.left.as_ref().ok_or_else(|| {
100+
substrait_datafusion_err!("SetComparison requires a left expression")
101+
})?;
102+
let right = comparison.right.as_ref().ok_or_else(|| {
103+
substrait_datafusion_err!("SetComparison requires a right relation")
104+
})?;
105+
let reduction_op = match ReductionOp::try_from(comparison.reduction_op) {
106+
Ok(ReductionOp::Any) => SetQuantifier::Any,
107+
Ok(ReductionOp::All) => SetQuantifier::All,
108+
_ => {
109+
return substrait_err!(
110+
"Unsupported reduction op for SetComparison: {}",
111+
comparison.reduction_op
112+
)
113+
}
114+
};
115+
let comparison_op = match ComparisonOp::try_from(comparison.comparison_op)
116+
{
117+
Ok(ComparisonOp::Eq) => Operator::Eq,
118+
Ok(ComparisonOp::Ne) => Operator::NotEq,
119+
Ok(ComparisonOp::Lt) => Operator::Lt,
120+
Ok(ComparisonOp::Gt) => Operator::Gt,
121+
Ok(ComparisonOp::Le) => Operator::LtEq,
122+
Ok(ComparisonOp::Ge) => Operator::GtEq,
123+
_ => {
124+
return substrait_err!(
125+
"Unsupported comparison op for SetComparison: {}",
126+
comparison.comparison_op
127+
)
128+
}
129+
};
130+
131+
let left_expr = consumer.consume_expression(left, input_schema).await?;
132+
let plan = consumer.consume_rel(right).await?;
133+
let outer_ref_columns = plan.all_out_ref_exprs();
134+
135+
Ok(Expr::SetComparison(SetComparison::new(
136+
Box::new(left_expr),
137+
Subquery {
138+
subquery: Arc::new(plan),
139+
outer_ref_columns,
140+
spans: Spans::new(),
141+
},
142+
comparison_op,
143+
reduction_op,
144+
)))
99145
}
100146
},
101147
None => {

datafusion/substrait/src/logical_plan/producer/expr/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ pub fn to_substrait_rex(
141141
Expr::InList(expr) => producer.handle_in_list(expr, schema),
142142
Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
143143
Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema),
144-
Expr::SetComparison(expr) => {
145-
not_impl_err!("Cannot convert {expr:?} to Substrait")
146-
}
144+
Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema),
147145
Expr::ScalarSubquery(expr) => {
148146
not_impl_err!("Cannot convert {expr:?} to Substrait")
149147
}

datafusion/substrait/src/logical_plan/producer/expr/subquery.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
// under the License.
1717

1818
use crate::logical_plan::producer::SubstraitProducer;
19-
use datafusion::common::DFSchemaRef;
20-
use datafusion::logical_expr::expr::InSubquery;
19+
use datafusion::common::{substrait_err, DFSchemaRef};
20+
use datafusion::logical_expr::expr::{InSubquery, SetComparison, SetQuantifier};
21+
use datafusion::logical_expr::Operator;
22+
use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp};
2123
use substrait::proto::expression::subquery::InPredicate;
2224
use substrait::proto::expression::{RexType, ScalarFunction};
2325
use substrait::proto::function_argument::ArgType;
@@ -70,3 +72,53 @@ pub fn from_in_subquery(
7072
Ok(substrait_subquery)
7173
}
7274
}
75+
76+
fn comparison_op_to_proto(op: &Operator) -> datafusion::common::Result<ComparisonOp> {
77+
match op {
78+
Operator::Eq => Ok(ComparisonOp::Eq),
79+
Operator::NotEq => Ok(ComparisonOp::Ne),
80+
Operator::Lt => Ok(ComparisonOp::Lt),
81+
Operator::Gt => Ok(ComparisonOp::Gt),
82+
Operator::LtEq => Ok(ComparisonOp::Le),
83+
Operator::GtEq => Ok(ComparisonOp::Ge),
84+
_ => substrait_err!("Unsupported operator {op:?} for SetComparison subquery"),
85+
}
86+
}
87+
88+
fn reduction_op_to_proto(
89+
quantifier: &SetQuantifier,
90+
) -> datafusion::common::Result<ReductionOp> {
91+
match quantifier {
92+
SetQuantifier::Any => Ok(ReductionOp::Any),
93+
SetQuantifier::All => Ok(ReductionOp::All),
94+
}
95+
}
96+
97+
pub fn from_set_comparison(
98+
producer: &mut impl SubstraitProducer,
99+
set_comparison: &SetComparison,
100+
schema: &DFSchemaRef,
101+
) -> datafusion::common::Result<Expression> {
102+
let comparison_op = comparison_op_to_proto(&set_comparison.op)? as i32;
103+
let reduction_op = reduction_op_to_proto(&set_comparison.quantifier)? as i32;
104+
let left = producer.handle_expr(set_comparison.expr.as_ref(), schema)?;
105+
let subquery_plan =
106+
producer.handle_plan(set_comparison.subquery.subquery.as_ref())?;
107+
108+
Ok(Expression {
109+
rex_type: Some(RexType::Subquery(Box::new(
110+
substrait::proto::expression::Subquery {
111+
subquery_type: Some(
112+
substrait::proto::expression::subquery::SubqueryType::SetComparison(
113+
Box::new(substrait::proto::expression::subquery::SetComparison {
114+
reduction_op,
115+
comparison_op,
116+
left: Some(Box::new(left)),
117+
right: Some(subquery_plan),
118+
}),
119+
),
120+
),
121+
},
122+
))),
123+
})
124+
}

datafusion/substrait/src/logical_plan/producer/substrait_producer.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,17 @@ use crate::logical_plan::producer::{
2020
from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr,
2121
from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter,
2222
from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal,
23-
from_projection, from_repartition, from_scalar_function, from_sort,
24-
from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union,
25-
from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex,
23+
from_projection, from_repartition, from_scalar_function, from_set_comparison,
24+
from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr,
25+
from_union, from_values, from_window, from_window_function, to_substrait_rel,
26+
to_substrait_rex,
2627
};
2728
use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue};
2829
use datafusion::execution::registry::SerializerRegistry;
2930
use datafusion::execution::SessionState;
30-
use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction};
31+
use datafusion::logical_expr::expr::{
32+
Alias, InList, InSubquery, SetComparison, WindowFunction,
33+
};
3134
use datafusion::logical_expr::{
3235
expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr,
3336
Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort,
@@ -359,6 +362,14 @@ pub trait SubstraitProducer: Send + Sync + Sized {
359362
) -> datafusion::common::Result<Expression> {
360363
from_in_subquery(self, in_subquery, schema)
361364
}
365+
366+
/// Converts a `SetComparison` expression into a Substrait `Expression`.
///
/// This provided implementation simply delegates to `from_set_comparison`,
/// passing the producer itself along with the expression and schema.
fn handle_set_comparison(
367+
&mut self,
368+
set_comparison: &SetComparison,
369+
schema: &DFSchemaRef,
370+
) -> datafusion::common::Result<Expression> {
371+
from_set_comparison(self, set_comparison, schema)
372+
}
362373
}
363374

364375
pub struct DefaultSubstraitProducer<'a> {

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ use std::mem::size_of_val;
2929

3030
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
3131
use datafusion::common::tree_node::Transformed;
32-
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
32+
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef, Spans};
3333
use datafusion::error::Result;
3434
use datafusion::execution::registry::SerializerRegistry;
3535
use datafusion::execution::runtime_env::RuntimeEnv;
3636
use datafusion::execution::session_state::SessionStateBuilder;
37+
use datafusion::logical_expr::expr::{SetComparison, SetQuantifier};
3738
use datafusion::logical_expr::{
38-
EmptyRelation, Extension, InvariantLevel, LogicalPlan, PartitionEvaluator,
39-
Repartition, UserDefinedLogicalNode, Values, Volatility,
39+
EmptyRelation, Extension, InvariantLevel, LogicalPlan, Operator, PartitionEvaluator,
40+
Repartition, Subquery, UserDefinedLogicalNode, Values, Volatility,
4041
};
4142
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
4243
use datafusion::prelude::*;
@@ -689,6 +690,29 @@ async fn roundtrip_exists_filter() -> Result<()> {
689690
Ok(())
690691
}
691692

693+
// assemble the logical plan manually to ensure the SetComparison expr is present (not rewritten away)
694+
#[tokio::test]
695+
async fn roundtrip_set_comparison_any_substrait() -> Result<()> {
696+
let ctx = create_context().await?;
697+
let plan = build_set_comparison_plan(&ctx, SetQuantifier::Any, Operator::Gt).await?;
698+
let proto = to_substrait_plan(&plan, &ctx.state())?;
699+
let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?;
700+
assert_set_comparison_predicate(&roundtrip_plan, Operator::Gt, SetQuantifier::Any);
701+
Ok(())
702+
}
703+
704+
// assemble the logical plan manually to ensure the SetComparison expr is present (not rewritten away)
705+
#[tokio::test]
706+
async fn roundtrip_set_comparison_all_substrait() -> Result<()> {
707+
let ctx = create_context().await?;
708+
let plan =
709+
build_set_comparison_plan(&ctx, SetQuantifier::All, Operator::NotEq).await?;
710+
let proto = to_substrait_plan(&plan, &ctx.state())?;
711+
let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?;
712+
assert_set_comparison_predicate(&roundtrip_plan, Operator::NotEq, SetQuantifier::All);
713+
Ok(())
714+
}
715+
692716
#[tokio::test]
693717
async fn roundtrip_not_exists_filter_left_anti_join() -> Result<()> {
694718
let plan = generate_plan_from_sql(
@@ -1865,6 +1889,56 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> {
18651889
Ok(())
18661890
}
18671891

1892+
async fn build_set_comparison_plan(
1893+
ctx: &SessionContext,
1894+
quantifier: SetQuantifier,
1895+
op: Operator,
1896+
) -> Result<LogicalPlan> {
1897+
let base_scan = ctx.table("data").await?.into_unoptimized_plan();
1898+
let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan();
1899+
let subquery_plan = LogicalPlanBuilder::from(subquery_scan)
1900+
.project(vec![col("data2.a")])?
1901+
.build()?;
1902+
let predicate = Expr::SetComparison(SetComparison::new(
1903+
Box::new(col("data.a")),
1904+
Subquery {
1905+
subquery: Arc::new(subquery_plan),
1906+
outer_ref_columns: vec![],
1907+
spans: Spans::new(),
1908+
},
1909+
op,
1910+
quantifier,
1911+
));
1912+
1913+
LogicalPlanBuilder::from(base_scan)
1914+
.filter(predicate)?
1915+
.project(vec![col("data.a")])?
1916+
.build()
1917+
}
1918+
1919+
fn assert_set_comparison_predicate(
1920+
plan: &LogicalPlan,
1921+
expected_op: Operator,
1922+
expected_quantifier: SetQuantifier,
1923+
) {
1924+
let predicate = match plan {
1925+
LogicalPlan::Projection(p) => match p.input.as_ref() {
1926+
LogicalPlan::Filter(filter) => &filter.predicate,
1927+
other => panic!("expected Filter inside Projection, got {other:?}"),
1928+
},
1929+
LogicalPlan::Filter(filter) => &filter.predicate,
1930+
other => panic!("expected Filter plan, got {other:?}"),
1931+
};
1932+
1933+
match predicate {
1934+
Expr::SetComparison(set_comparison) => {
1935+
assert_eq!(set_comparison.op, expected_op);
1936+
assert_eq!(set_comparison.quantifier, expected_quantifier);
1937+
}
1938+
other => panic!("expected SetComparison predicate, got {other:?}"),
1939+
}
1940+
}
1941+
18681942
async fn roundtrip_fill_na(sql: &str) -> Result<()> {
18691943
let ctx = create_context().await?;
18701944
let df = ctx.sql(sql).await?;

0 commit comments

Comments
 (0)