Skip to content

Commit 729c9d2

Browse files
committed
refactor TreeNode::rewrite()
1 parent ff7dfc3 commit 729c9d2

File tree

37 files changed

+355
-351
lines changed

37 files changed

+355
-351
lines changed

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {
9191

9292
impl MyAnalyzerRule {
9393
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
94-
plan.transform(&|plan| {
94+
plan.transform_up(&|plan| {
9595
Ok(match plan {
9696
LogicalPlan::Filter(filter) => {
9797
let predicate = Self::analyze_expr(filter.predicate.clone())?;
@@ -106,7 +106,7 @@ impl MyAnalyzerRule {
106106
}
107107

108108
fn analyze_expr(expr: Expr) -> Result<Expr> {
109-
expr.transform(&|expr| {
109+
expr.transform_up(&|expr| {
110110
// closure is invoked for all sub expressions
111111
Ok(match expr {
112112
Expr::Literal(ScalarValue::Int64(i)) => {
@@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {
161161

162162
/// use rewrite_expr to modify the expression tree.
163163
fn my_rewrite(expr: Expr) -> Result<Expr> {
164-
expr.transform(&|expr| {
164+
expr.transform_up(&|expr| {
165165
// closure is invoked for all sub expressions
166166
Ok(match expr {
167167
Expr::Between(Between {

datafusion/common/src/tree_node.rs

Lines changed: 104 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ use crate::Result;
3131
macro_rules! handle_tree_recursion {
3232
($EXPR:expr) => {
3333
match $EXPR {
34-
VisitRecursion::Continue => {}
34+
TreeNodeRecursion::Continue => {}
3535
// If the recursion should skip, do not apply to its children, let
3636
// the recursion continue:
37-
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
37+
TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue),
3838
// If the recursion should stop, do not apply to its children:
39-
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
39+
TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
4040
}
4141
};
4242
}
@@ -58,10 +58,10 @@ pub trait TreeNode: Sized {
5858
///
5959
/// The `op` closure can be used to collect some info from the
6060
/// tree node or do some checking for the tree node.
61-
fn apply<F: FnMut(&Self) -> Result<VisitRecursion>>(
61+
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
6262
&self,
6363
op: &mut F,
64-
) -> Result<VisitRecursion> {
64+
) -> Result<TreeNodeRecursion> {
6565
handle_tree_recursion!(op(self)?);
6666
self.apply_children(&mut |node| node.apply(op))
6767
}
@@ -88,7 +88,7 @@ pub trait TreeNode: Sized {
8888
///
8989
/// If an Err result is returned, recursion is stopped immediately
9090
///
91-
/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
91+
/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
9292
/// children of that node will be visited, nor is post_visit
9393
/// called on that node. Details see [`TreeNodeVisitor`]
9494
///
@@ -97,20 +97,53 @@ pub trait TreeNode: Sized {
9797
fn visit<V: TreeNodeVisitor<N = Self>>(
9898
&self,
9999
visitor: &mut V,
100-
) -> Result<VisitRecursion> {
100+
) -> Result<TreeNodeRecursion> {
101101
handle_tree_recursion!(visitor.pre_visit(self)?);
102102
handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?);
103103
visitor.post_visit(self)
104104
}
105105

106-
/// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree.
107-
/// When `op` does not apply to a given node, it is left unchanged.
108-
/// The default tree traversal direction is transform_up(Postorder Traversal).
109-
fn transform<F>(self, op: &F) -> Result<Self>
106+
/// Transforms the tree using `f_down` while traversing the tree top-down
107+
/// (pre-order) and using `f_up` while traversing the tree bottom-up (post-order).
108+
///
109+
/// E.g. for a tree such as:
110+
/// ```text
111+
/// ParentNode
112+
/// left: ChildNode1
113+
/// right: ChildNode2
114+
/// ```
115+
///
116+
/// The nodes are visited using the following order:
117+
/// ```text
118+
/// f_down(ParentNode)
119+
/// f_down(ChildNode1)
120+
/// f_up(ChildNode1)
121+
/// f_down(ChildNode2)
122+
/// f_up(ChildNode2)
123+
/// f_up(ParentNode)
124+
/// ```
125+
///
126+
/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
127+
///
128+
/// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately.
129+
fn transform<FD, FU>(self, f_down: &mut FD, f_up: &mut FU) -> Result<Self>
110130
where
111-
F: Fn(Self) -> Result<Transformed<Self>>,
131+
FD: FnMut(Self) -> Result<(Transformed<Self>, TreeNodeRecursion)>,
132+
FU: FnMut(Self) -> Result<Self>,
112133
{
113-
self.transform_up(op)
134+
let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?;
135+
match tnr {
136+
TreeNodeRecursion::Continue => {}
137+
// If the recursion should skip, do not apply to its children, and let the recursion continue
138+
TreeNodeRecursion::Skip => return Ok(new_node),
139+
// If the recursion should stop, do not apply to its children
140+
TreeNodeRecursion::Stop => {
141+
panic!("Stop can't be used in TreeNode::transform()")
142+
}
143+
}
144+
let node_with_new_children =
145+
new_node.map_children(|node| node.transform(f_down, f_up))?;
146+
f_up(node_with_new_children)
114147
}
115148

116149
/// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
@@ -159,56 +192,50 @@ pub trait TreeNode: Sized {
159192
Ok(new_node)
160193
}
161194

162-
/// Transform the tree node using the given [TreeNodeRewriter]
163-
/// It performs a depth first walk of an node and its children.
195+
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
196+
/// recursively transforming [`TreeNode`]s.
164197
///
165-
/// For an node tree such as
198+
/// E.g. for a tree such as:
166199
/// ```text
167200
/// ParentNode
168201
/// left: ChildNode1
169202
/// right: ChildNode2
170203
/// ```
171204
///
172-
/// The nodes are visited using the following order
205+
/// The nodes are visited using the following order:
173206
/// ```text
174-
/// pre_visit(ParentNode)
175-
/// pre_visit(ChildNode1)
176-
/// mutate(ChildNode1)
177-
/// pre_visit(ChildNode2)
178-
/// mutate(ChildNode2)
179-
/// mutate(ParentNode)
207+
/// TreeNodeRewriter::f_down(ParentNode)
208+
/// TreeNodeRewriter::f_down(ChildNode1)
209+
/// TreeNodeRewriter::f_up(ChildNode1)
210+
/// TreeNodeRewriter::f_down(ChildNode2)
211+
/// TreeNodeRewriter::f_up(ChildNode2)
212+
/// TreeNodeRewriter::f_up(ParentNode)
180213
/// ```
181214
///
182-
/// If an Err result is returned, recursion is stopped immediately
183-
///
184-
/// If [`false`] is returned on a call to pre_visit, no
185-
/// children of that node will be visited, nor is mutate
186-
/// called on that node
215+
/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
187216
///
188-
/// If using the default [`TreeNodeRewriter::pre_visit`] which
189-
/// returns `true`, [`Self::transform`] should be preferred.
190-
fn rewrite<R: TreeNodeRewriter<N = Self>>(self, rewriter: &mut R) -> Result<Self> {
191-
let need_mutate = match rewriter.pre_visit(&self)? {
192-
RewriteRecursion::Mutate => return rewriter.mutate(self),
193-
RewriteRecursion::Stop => return Ok(self),
194-
RewriteRecursion::Continue => true,
195-
RewriteRecursion::Skip => false,
196-
};
197-
198-
let after_op_children = self.map_children(|node| node.rewrite(rewriter))?;
199-
200-
// now rewrite this node itself
201-
if need_mutate {
202-
rewriter.mutate(after_op_children)
203-
} else {
204-
Ok(after_op_children)
217+
/// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`],
218+
/// recursion is stopped immediately.
219+
fn rewrite<R: TreeNodeRewriter<Node = Self>>(self, rewriter: &mut R) -> Result<Self> {
220+
let (new_node, tnr) = rewriter.f_down(self)?;
221+
match tnr {
222+
TreeNodeRecursion::Continue => {}
223+
// If the recursion should skip, do not apply to its children, and let the recursion continue
224+
TreeNodeRecursion::Skip => return Ok(new_node),
225+
// If the recursion should stop, do not apply to its children
226+
TreeNodeRecursion::Stop => {
227+
panic!("Stop can't be used in TreeNode::rewrite()")
228+
}
205229
}
230+
let node_with_new_children =
231+
new_node.map_children(|node| node.rewrite(rewriter))?;
232+
rewriter.f_up(node_with_new_children)
206233
}
207234

208235
/// Apply the closure `F` to the node's children
209-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
236+
fn apply_children<F>(&self, op: &mut F) -> Result<TreeNodeRecursion>
210237
where
211-
F: FnMut(&Self) -> Result<VisitRecursion>;
238+
F: FnMut(&Self) -> Result<TreeNodeRecursion>;
212239

213240
/// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder)
214241
fn map_children<F>(self, transform: F) -> Result<Self>
@@ -231,69 +258,58 @@ pub trait TreeNode: Sized {
231258
/// If an [`Err`] result is returned, recursion is stopped
232259
/// immediately.
233260
///
234-
/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
261+
/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
235262
/// children of that tree node are visited, nor is post_visit
236263
/// called on that tree node
237264
///
238-
/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no
265+
/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no
239266
/// siblings of that tree node are visited, nor is post_visit
240267
/// called on its parent tree node
241268
///
242-
/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no
269+
/// If [`TreeNodeRecursion::Skip`] is returned on a call to pre_visit, no
243270
/// children of that tree node are visited.
244271
pub trait TreeNodeVisitor: Sized {
245272
/// The node type which is visitable.
246273
type N: TreeNode;
247274

248275
/// Invoked before any children of `node` are visited.
249-
fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion>;
276+
fn pre_visit(&mut self, node: &Self::N) -> Result<TreeNodeRecursion>;
250277

251278
/// Invoked after all children of `node` are visited. Default
252279
/// implementation does nothing.
253-
fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
254-
Ok(VisitRecursion::Continue)
280+
fn post_visit(&mut self, _node: &Self::N) -> Result<TreeNodeRecursion> {
281+
Ok(TreeNodeRecursion::Continue)
255282
}
256283
}
257284

258-
/// Trait for potentially recursively transform an [`TreeNode`] node
259-
/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is
260-
/// invoked recursively on all nodes of a tree.
285+
/// Trait for potentially recursively transforming a [`TreeNode`] node tree.
261286
pub trait TreeNodeRewriter: Sized {
262287
/// The node type which is rewritable.
263-
type N: TreeNode;
288+
type Node: TreeNode;
264289

265-
/// Invoked before (Preorder) any children of `node` are rewritten /
266-
/// visited. Default implementation returns `Ok(Recursion::Continue)`
267-
fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
268-
Ok(RewriteRecursion::Continue)
290+
/// Invoked while traversing down the tree before any children are rewritten /
291+
/// visited.
292+
/// Default implementation returns the node unmodified and continues recursion.
293+
fn f_down(&mut self, node: Self::Node) -> Result<(Self::Node, TreeNodeRecursion)> {
294+
Ok((node, TreeNodeRecursion::Continue))
269295
}
270296

271-
/// Invoked after (Postorder) all children of `node` have been mutated and
272-
/// returns a potentially modified node.
273-
fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
274-
}
275-
276-
/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`].
277-
#[derive(Debug)]
278-
pub enum RewriteRecursion {
279-
/// Continue rewrite this node tree.
280-
Continue,
281-
/// Call 'op' immediately and return.
282-
Mutate,
283-
/// Do not rewrite the children of this node.
284-
Stop,
285-
/// Keep recursive but skip apply op on this node
286-
Skip,
297+
/// Invoked while traversing up the tree after all children have been rewritten /
298+
/// visited.
299+
/// Default implementation returns the node unmodified.
300+
fn f_up(&mut self, node: Self::Node) -> Result<Self::Node> {
301+
Ok(node)
302+
}
287303
}
288304

289-
/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`].
305+
/// Controls how [`TreeNode`] recursions should proceed.
290306
#[derive(Debug)]
291-
pub enum VisitRecursion {
292-
/// Continue the visit to this node tree.
307+
pub enum TreeNodeRecursion {
308+
/// Continue recursion with the next node.
293309
Continue,
294-
/// Keep recursive but skip applying op on the children
310+
/// Skip the current subtree.
295311
Skip,
296-
/// Stop the visit to this node tree.
312+
/// Stop recursion.
297313
Stop,
298314
}
299315

@@ -340,14 +356,14 @@ pub trait DynTreeNode {
340356
/// [`DynTreeNode`] (such as [`Arc<dyn PhysicalExpr>`])
341357
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
342358
/// Apply the closure `F` to the node's children
343-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
359+
fn apply_children<F>(&self, op: &mut F) -> Result<TreeNodeRecursion>
344360
where
345-
F: FnMut(&Self) -> Result<VisitRecursion>,
361+
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
346362
{
347363
for child in self.arc_children() {
348364
handle_tree_recursion!(op(&child)?)
349365
}
350-
Ok(VisitRecursion::Continue)
366+
Ok(TreeNodeRecursion::Continue)
351367
}
352368

353369
fn map_children<F>(self, transform: F) -> Result<Self>
@@ -382,14 +398,14 @@ pub trait ConcreteTreeNode: Sized {
382398

383399
impl<T: ConcreteTreeNode> TreeNode for T {
384400
/// Apply the closure `F` to the node's children
385-
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
401+
fn apply_children<F>(&self, op: &mut F) -> Result<TreeNodeRecursion>
386402
where
387-
F: FnMut(&Self) -> Result<VisitRecursion>,
403+
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
388404
{
389405
for child in self.children() {
390406
handle_tree_recursion!(op(child)?)
391407
}
392-
Ok(VisitRecursion::Continue)
408+
Ok(TreeNodeRecursion::Continue)
393409
}
394410

395411
fn map_children<F>(self, transform: F) -> Result<Self>

0 commit comments

Comments
 (0)