Skip to content

Commit 02ef331

Browse files
committed
document it
1 parent 06cdf67 commit 02ef331

1 file changed

Lines changed: 43 additions & 1 deletion

File tree

crates/ty_python_semantic/src/types/constraints.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,20 +1184,59 @@ impl<'db> Node<'db> {
11841184
}
11851185
}
11861186

1187+
/// Combine an iterator of nodes into a single node using an associative operator.
1188+
///
1189+
/// Because the operator is associative, we don't have to combine the nodes left to right; we
1190+
/// can instead combine them in a "tree-like" way:
1191+
///
1192+
/// ```text
1193+
/// linear: (((((a ∨ b) ∨ c) ∨ d) ∨ e) ∨ f) ∨ g
1194+
/// tree: ((a ∨ b) ∨ (c ∨ d)) ∨ ((e ∨ f) ∨ g)
1195+
/// ```
1196+
///
1197+
/// We have to invoke the operator the same number of times. But BDD operators are often much
1198+
/// cheaper when the operands are small, and with the tree shape, many more of the invocations
1199+
/// are performed on small BDDs.
1200+
///
1201+
/// You must also provide the "zero" and "one" units of the operator. The "zero" is the value
1202+
/// that has no effect (`0 ∨ a = a`). It is returned if the iterator is empty. The "one" is the
1203+
/// value that saturates (`1 ∨ a = 1`). We use this to short-circuit; if any element BDD or any
1204+
/// intermediate result evaluates to "one", we can return early.
11871205
fn tree_fold(
11881206
db: &'db dyn Db,
11891207
nodes: impl Iterator<Item = Self>,
11901208
zero: Self,
11911209
one: Self,
11921210
mut combine: impl FnMut(Self, &'db dyn Db, Self) -> Self,
11931211
) -> Self {
1212+
// To implement the "linear" shape described above, we could collect the iterator elements
1213+
// into a vector, and then use the fold at the bottom of this method to combine the
1214+
// elements using the operator.
1215+
//
1216+
// To implement the "tree" shape, we also maintain a "depth" for each element of the
1217+
// vector, which indicates how many times the operator has been applied to the element.
1218+
// As we collect elements into the vector, we keep it capped at a length `O(log n)` of the
1219+
// number of elements seen so far. To do that, whenever the last two elements of the vector
1220+
// have the same depth, we apply the operator once to combine those two elements, adding
1221+
// the result back to the vector with an incremented depth. (That might let us combine the
1222+
// result with the _next_ intermediate result in the vector, and so on.)
1223+
//
1224+
// Walking through the example above, our vector ends up looking like:
1225+
//
1226+
// a/0
1227+
// a/0 b/0 => a∨b/1
1228+
// a∨b/1 c/0
1229+
// a∨b/1 c/0 d/0 => a∨b/1 c∨d/1 => a∨b∨c∨d/2
1230+
// a∨b∨c∨d/2 e/0
1231+
// a∨b∨c∨d/2 e/0 f/0 => a∨b∨c∨d/2 e∨f/1
1232+
// a∨b∨c∨d/2 e∨f/1 g/0
11941233
let mut accumulator: SmallVec<[(Node<'db>, u8); 8]> = SmallVec::default();
11951234
for node in nodes {
11961235
if node == one {
11971236
return one;
11981237
}
11991238

1200-
let (mut node, mut depth) = (node, 1);
1239+
let (mut node, mut depth) = (node, 0);
12011240
while accumulator
12021241
.last()
12031242
.is_some_and(|(_, existing)| *existing == depth)
@@ -1212,6 +1251,9 @@ impl<'db> Node<'db> {
12121251
accumulator.push((node, depth));
12131252
}
12141253

1254+
// At this point, we've consumed all of the iterator. The length of the accumulator will be
1255+
// the same as the number of 1 bits in the length of the iterator. We do a final fold to
1256+
// produce the overall result.
12151257
accumulator
12161258
.into_iter()
12171259
.fold(zero, |result, (node, _)| combine(result, db, node))

0 commit comments

Comments
 (0)