@@ -31,12 +31,12 @@ use crate::Result;
3131macro_rules! handle_tree_recursion {
3232 ( $EXPR: expr) => {
3333 match $EXPR {
34- VisitRecursion :: Continue => { }
34+ TreeNodeRecursion :: Continue => { }
3535 // If the recursion should skip, do not apply to its children, let
3636 // the recursion continue:
37- VisitRecursion :: Skip => return Ok ( VisitRecursion :: Continue ) ,
37+ TreeNodeRecursion :: Skip => return Ok ( TreeNodeRecursion :: Continue ) ,
3838 // If the recursion should stop, do not apply to its children:
39- VisitRecursion :: Stop => return Ok ( VisitRecursion :: Stop ) ,
39+ TreeNodeRecursion :: Stop => return Ok ( TreeNodeRecursion :: Stop ) ,
4040 }
4141 } ;
4242}
@@ -58,10 +58,10 @@ pub trait TreeNode: Sized {
5858 ///
5959 /// The `op` closure can be used to collect some info from the
6060 /// tree node or do some checking for the tree node.
61- fn apply < F : FnMut ( & Self ) -> Result < VisitRecursion > > (
61+ fn apply < F : FnMut ( & Self ) -> Result < TreeNodeRecursion > > (
6262 & self ,
6363 op : & mut F ,
64- ) -> Result < VisitRecursion > {
64+ ) -> Result < TreeNodeRecursion > {
6565 handle_tree_recursion ! ( op( self ) ?) ;
6666 self . apply_children ( & mut |node| node. apply ( op) )
6767 }
@@ -88,7 +88,7 @@ pub trait TreeNode: Sized {
8888 ///
8989 /// If an Err result is returned, recursion is stopped immediately
9090 ///
91- /// If [`VisitRecursion ::Stop`] is returned on a call to pre_visit, no
91+ /// If [`TreeNodeRecursion ::Stop`] is returned on a call to pre_visit, no
9292 /// children of that node will be visited, nor is post_visit
9393 /// called on that node. Details see [`TreeNodeVisitor`]
9494 ///
@@ -97,20 +97,53 @@ pub trait TreeNode: Sized {
9797 fn visit < V : TreeNodeVisitor < N = Self > > (
9898 & self ,
9999 visitor : & mut V ,
100- ) -> Result < VisitRecursion > {
100+ ) -> Result < TreeNodeRecursion > {
101101 handle_tree_recursion ! ( visitor. pre_visit( self ) ?) ;
102102 handle_tree_recursion ! ( self . apply_children( & mut |node| node. visit( visitor) ) ?) ;
103103 visitor. post_visit ( self )
104104 }
105105
106- /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree.
107- /// When `op` does not apply to a given node, it is left unchanged.
108- /// The default tree traversal direction is transform_up(Postorder Traversal).
109- fn transform < F > ( self , op : & F ) -> Result < Self >
106+ /// Transforms the tree using `f_down` while traversing the tree top-down
107+ /// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order).
108+ ///
109+ /// E.g. for an tree such as:
110+ /// ```text
111+ /// ParentNode
112+ /// left: ChildNode1
113+ /// right: ChildNode2
114+ /// ```
115+ ///
116+ /// The nodes are visited using the following order:
117+ /// ```text
118+ /// f_down(ParentNode)
119+ /// f_down(ChildNode1)
120+ /// f_up(ChildNode1)
121+ /// f_down(ChildNode2)
122+ /// f_up(ChildNode2)
123+ /// f_up(ParentNode)
124+ /// ```
125+ ///
126+ /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
127+ ///
128+ /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately.
129+ fn transform < FD , FU > ( self , f_down : & mut FD , f_up : & mut FU ) -> Result < Self >
110130 where
111- F : Fn ( Self ) -> Result < Transformed < Self > > ,
131+ FD : FnMut ( Self ) -> Result < ( Transformed < Self > , TreeNodeRecursion ) > ,
132+ FU : FnMut ( Self ) -> Result < Self > ,
112133 {
113- self . transform_up ( op)
134+ let ( new_node, tnr) = f_down ( self ) . map ( |( t, tnr) | ( t. into ( ) , tnr) ) ?;
135+ match tnr {
136+ TreeNodeRecursion :: Continue => { }
137+ // If the recursion should skip, do not apply to its children. And let the recursion continue
138+ TreeNodeRecursion :: Skip => return Ok ( new_node) ,
139+ // If the recursion should stop, do not apply to its children
140+ TreeNodeRecursion :: Stop => {
141+ panic ! ( "Stop can't be used in TreeNode::transform()" )
142+ }
143+ }
144+ let node_with_new_children =
145+ new_node. map_children ( |node| node. transform ( f_down, f_up) ) ?;
146+ f_up ( node_with_new_children)
114147 }
115148
116149 /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
@@ -159,56 +192,50 @@ pub trait TreeNode: Sized {
159192 Ok ( new_node)
160193 }
161194
162- /// Transform the tree node using the given [TreeNodeRewriter]
163- /// It performs a depth first walk of an node and its children .
195+ /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
196+ /// recursively transforming [`TreeNode`]s .
164197 ///
165- /// For an node tree such as
198+ /// E.g. for an tree such as:
166199 /// ```text
167200 /// ParentNode
168201 /// left: ChildNode1
169202 /// right: ChildNode2
170203 /// ```
171204 ///
172- /// The nodes are visited using the following order
205+ /// The nodes are visited using the following order:
173206 /// ```text
174- /// pre_visit (ParentNode)
175- /// pre_visit (ChildNode1)
176- /// mutate (ChildNode1)
177- /// pre_visit (ChildNode2)
178- /// mutate (ChildNode2)
179- /// mutate (ParentNode)
207+ /// TreeNodeRewriter::f_down (ParentNode)
208+ /// TreeNodeRewriter::f_down (ChildNode1)
209+ /// TreeNodeRewriter::f_up (ChildNode1)
210+ /// TreeNodeRewriter::f_down (ChildNode2)
211+ /// TreeNodeRewriter::f_up (ChildNode2)
212+ /// TreeNodeRewriter::f_up (ParentNode)
180213 /// ```
181214 ///
182- /// If an Err result is returned, recursion is stopped immediately
183- ///
184- /// If [`false`] is returned on a call to pre_visit, no
185- /// children of that node will be visited, nor is mutate
186- /// called on that node
215+ /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
187216 ///
188- /// If using the default [`TreeNodeRewriter::pre_visit`] which
189- /// returns `true`, [`Self::transform`] should be preferred.
190- fn rewrite < R : TreeNodeRewriter < N = Self > > ( self , rewriter : & mut R ) -> Result < Self > {
191- let need_mutate = match rewriter. pre_visit ( & self ) ? {
192- RewriteRecursion :: Mutate => return rewriter. mutate ( self ) ,
193- RewriteRecursion :: Stop => return Ok ( self ) ,
194- RewriteRecursion :: Continue => true ,
195- RewriteRecursion :: Skip => false ,
196- } ;
197-
198- let after_op_children = self . map_children ( |node| node. rewrite ( rewriter) ) ?;
199-
200- // now rewrite this node itself
201- if need_mutate {
202- rewriter. mutate ( after_op_children)
203- } else {
204- Ok ( after_op_children)
217+ /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`],
218+ /// recursion is stopped immediately.
219+ fn rewrite < R : TreeNodeRewriter < Node = Self > > ( self , rewriter : & mut R ) -> Result < Self > {
220+ let ( new_node, tnr) = rewriter. f_down ( self ) ?;
221+ match tnr {
222+ TreeNodeRecursion :: Continue => { }
223+ // If the recursion should skip, do not apply to its children. And let the recursion continue
224+ TreeNodeRecursion :: Skip => return Ok ( new_node) ,
225+ // If the recursion should stop, do not apply to its children
226+ TreeNodeRecursion :: Stop => {
227+ panic ! ( "Stop can't be used in TreeNode::rewrite()" )
228+ }
205229 }
230+ let node_with_new_children =
231+ new_node. map_children ( |node| node. rewrite ( rewriter) ) ?;
232+ rewriter. f_up ( node_with_new_children)
206233 }
207234
208235 /// Apply the closure `F` to the node's children
209- fn apply_children < F > ( & self , op : & mut F ) -> Result < VisitRecursion >
236+ fn apply_children < F > ( & self , op : & mut F ) -> Result < TreeNodeRecursion >
210237 where
211- F : FnMut ( & Self ) -> Result < VisitRecursion > ;
238+ F : FnMut ( & Self ) -> Result < TreeNodeRecursion > ;
212239
213240 /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder)
214241 fn map_children < F > ( self , transform : F ) -> Result < Self >
@@ -231,69 +258,58 @@ pub trait TreeNode: Sized {
231258/// If an [`Err`] result is returned, recursion is stopped
232259/// immediately.
233260///
234- /// If [`VisitRecursion ::Stop`] is returned on a call to pre_visit, no
261+ /// If [`TreeNodeRecursion ::Stop`] is returned on a call to pre_visit, no
235262/// children of that tree node are visited, nor is post_visit
236263/// called on that tree node
237264///
238- /// If [`VisitRecursion ::Stop`] is returned on a call to post_visit, no
265+ /// If [`TreeNodeRecursion ::Stop`] is returned on a call to post_visit, no
239266/// siblings of that tree node are visited, nor is post_visit
240267/// called on its parent tree node
241268///
242- /// If [`VisitRecursion ::Skip`] is returned on a call to pre_visit, no
269+ /// If [`TreeNodeRecursion ::Skip`] is returned on a call to pre_visit, no
243270/// children of that tree node are visited.
244271pub trait TreeNodeVisitor : Sized {
245272 /// The node type which is visitable.
246273 type N : TreeNode ;
247274
248275 /// Invoked before any children of `node` are visited.
249- fn pre_visit ( & mut self , node : & Self :: N ) -> Result < VisitRecursion > ;
276+ fn pre_visit ( & mut self , node : & Self :: N ) -> Result < TreeNodeRecursion > ;
250277
251278 /// Invoked after all children of `node` are visited. Default
252279 /// implementation does nothing.
253- fn post_visit ( & mut self , _node : & Self :: N ) -> Result < VisitRecursion > {
254- Ok ( VisitRecursion :: Continue )
280+ fn post_visit ( & mut self , _node : & Self :: N ) -> Result < TreeNodeRecursion > {
281+ Ok ( TreeNodeRecursion :: Continue )
255282 }
256283}
257284
258- /// Trait for potentially recursively transform an [`TreeNode`] node
259- /// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is
260- /// invoked recursively on all nodes of a tree.
285+ /// Trait for potentially recursively transform a [`TreeNode`] node tree.
261286pub trait TreeNodeRewriter : Sized {
262287 /// The node type which is rewritable.
263- type N : TreeNode ;
288+ type Node : TreeNode ;
264289
265- /// Invoked before (Preorder) any children of `node` are rewritten /
266- /// visited. Default implementation returns `Ok(Recursion::Continue)`
267- fn pre_visit ( & mut self , _node : & Self :: N ) -> Result < RewriteRecursion > {
268- Ok ( RewriteRecursion :: Continue )
290+ /// Invoked while traversing down the tree before any children are rewritten /
291+ /// visited.
292+ /// Default implementation returns the node unmodified and continues recursion.
293+ fn f_down ( & mut self , node : Self :: Node ) -> Result < ( Self :: Node , TreeNodeRecursion ) > {
294+ Ok ( ( node, TreeNodeRecursion :: Continue ) )
269295 }
270296
271- /// Invoked after (Postorder) all children of `node` have been mutated and
272- /// returns a potentially modified node.
273- fn mutate ( & mut self , node : Self :: N ) -> Result < Self :: N > ;
274- }
275-
276- /// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`].
277- #[ derive( Debug ) ]
278- pub enum RewriteRecursion {
279- /// Continue rewrite this node tree.
280- Continue ,
281- /// Call 'op' immediately and return.
282- Mutate ,
283- /// Do not rewrite the children of this node.
284- Stop ,
285- /// Keep recursive but skip apply op on this node
286- Skip ,
297+ /// Invoked while traversing up the tree after all children have been rewritten /
298+ /// visited.
299+ /// Default implementation returns the node unmodified.
300+ fn f_up ( & mut self , node : Self :: Node ) -> Result < Self :: Node > {
301+ Ok ( node)
302+ }
287303}
288304
289- /// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`] .
305+ /// Controls how [`TreeNode`] recursions should proceed.
290306#[ derive( Debug ) ]
291- pub enum VisitRecursion {
292- /// Continue the visit to this node tree .
307+ pub enum TreeNodeRecursion {
308+ /// Continue recursion with the next node.
293309 Continue ,
294- /// Keep recursive but skip applying op on the children
310+ /// Skip the current subtree.
295311 Skip ,
296- /// Stop the visit to this node tree .
312+ /// Stop recursion .
297313 Stop ,
298314}
299315
@@ -340,14 +356,14 @@ pub trait DynTreeNode {
340356/// [`DynTreeNode`] (such as [`Arc<dyn PhysicalExpr>`])
341357impl < T : DynTreeNode + ?Sized > TreeNode for Arc < T > {
342358 /// Apply the closure `F` to the node's children
343- fn apply_children < F > ( & self , op : & mut F ) -> Result < VisitRecursion >
359+ fn apply_children < F > ( & self , op : & mut F ) -> Result < TreeNodeRecursion >
344360 where
345- F : FnMut ( & Self ) -> Result < VisitRecursion > ,
361+ F : FnMut ( & Self ) -> Result < TreeNodeRecursion > ,
346362 {
347363 for child in self . arc_children ( ) {
348364 handle_tree_recursion ! ( op( & child) ?)
349365 }
350- Ok ( VisitRecursion :: Continue )
366+ Ok ( TreeNodeRecursion :: Continue )
351367 }
352368
353369 fn map_children < F > ( self , transform : F ) -> Result < Self >
@@ -382,14 +398,14 @@ pub trait ConcreteTreeNode: Sized {
382398
383399impl < T : ConcreteTreeNode > TreeNode for T {
384400 /// Apply the closure `F` to the node's children
385- fn apply_children < F > ( & self , op : & mut F ) -> Result < VisitRecursion >
401+ fn apply_children < F > ( & self , op : & mut F ) -> Result < TreeNodeRecursion >
386402 where
387- F : FnMut ( & Self ) -> Result < VisitRecursion > ,
403+ F : FnMut ( & Self ) -> Result < TreeNodeRecursion > ,
388404 {
389405 for child in self . children ( ) {
390406 handle_tree_recursion ! ( op( child) ?)
391407 }
392- Ok ( VisitRecursion :: Continue )
408+ Ok ( TreeNodeRecursion :: Continue )
393409 }
394410
395411 fn map_children < F > ( self , transform : F ) -> Result < Self >
0 commit comments