@@ -1184,20 +1184,62 @@ impl<'db> Node<'db> {
11841184 }
11851185 }
11861186
1187+ /// Combine an iterator of nodes into a single node using an associative operator.
1188+ ///
1189+ /// Because the operator is associative, we don't have to combine the nodes left to right; we
1190+ /// can instead combine them in a "tree-like" way:
1191+ ///
1192+ /// ```text
1193+ /// linear: (((((a ∨ b) ∨ c) ∨ d) ∨ e) ∨ f) ∨ g
1194+ /// tree: ((a ∨ b) ∨ (c ∨ d)) ∨ ((e ∨ f) ∨ g)
1195+ /// ```
1196+ ///
1197+ /// We have to invoke the operator the same number of times. But BDD operators are often much
1198+ /// cheaper when the operands are small, and with the tree shape, many more of the invocations
1199+ /// are performed on small BDDs.
1200+ ///
1201+ /// You must also provide the "zero" and "one" units of the operator. The "zero" is the value
1202+ /// that has no effect (`0 ∨ a = a`). It is returned if the iterator is empty. The "one" is the
1203+ /// value that saturates (`1 ∨ a = 1`). We use this to short-circuit; if any element BDD or any
1204+ /// intermediate result evaluates to "one", we can return early.
11871205 fn tree_fold (
11881206 db : & ' db dyn Db ,
11891207 nodes : impl Iterator < Item = Self > ,
11901208 zero : Self ,
11911209 one : Self ,
11921210 mut combine : impl FnMut ( Self , & ' db dyn Db , Self ) -> Self ,
11931211 ) -> Self {
1212+ // To implement the "linear" shape described above, we could collect the iterator elements
1213+ // into a vector, and then use the fold at the bottom of this method to combine the
1214+ // elements using the operator.
1215+ //
1216+ // To implement the "tree" shape, we also maintain a "depth" for each element of the
1217+ // vector, which indicates how many times the operator has been applied to the element.
1218+ // As we collect elements into the vector, we keep it capped at a length `O(log n)` of the
1219+ // number of elements seen so far. To do that, whenever the last two elements of the vector
1220+ // have the same depth, we apply the operator once to combine those two elements, adding
1221+ // the result back to the vector with an incremented depth. (That might let us combine the
1222+ // result with the _next_ intermediate result in the vector, and so on.)
1223+ //
1224+ // Walking through the example above, our vector ends up looking like:
1225+ //
1226+ // a/0
1227+ // a/0 b/0 => a∨b/1
1228+ // a∨b/1 c/0
1229+ // a∨b/1 c/0 d/0 => a∨b/1 c∨d/1 => a∨b∨c∨d/2
1230+ // a∨b∨c∨d/2 e/0
1231+ // a∨b∨c∨d/2 e/0 f/0 => a∨b∨c∨d/2 e∨f/1
1232+ // a∨b∨c∨d/2 e∨f/1 g/0
1233+ //
1234+ // We use a SmallVec for the accumulator so that we don't have to spill over to the heap
1235+ // until the iterator passes 256 elements.
11941236 let mut accumulator: SmallVec < [ ( Node < ' db > , u8 ) ; 8 ] > = SmallVec :: default ( ) ;
11951237 for node in nodes {
11961238 if node == one {
11971239 return one;
11981240 }
11991241
1200- let ( mut node, mut depth) = ( node, 1 ) ;
1242+ let ( mut node, mut depth) = ( node, 0 ) ;
12011243 while accumulator
12021244 . last ( )
12031245 . is_some_and ( |( _, existing) | * existing == depth)
@@ -1212,6 +1254,9 @@ impl<'db> Node<'db> {
12121254 accumulator. push ( ( node, depth) ) ;
12131255 }
12141256
1257+ // At this point, we've consumed all of the iterator. The length of the accumulator will be
1258+ // the same as the number of 1 bits in the length of the iterator. We do a final fold to
1259+ // produce the overall result.
12151260 accumulator
12161261 . into_iter ( )
12171262 . fold ( zero, |result, ( node, _) | combine ( result, db, node) )
0 commit comments