Skip to content

Commit 2643fb0

Browse files
authored
[ty] Use distributed versions of AND and OR on constraint sets (#22614)
There are some pathological examples where we create a constraint set which is the AND or OR of several smaller constraint sets. For example, when calling a function with many overloads, where the argument is a typevar, we create an OR of the typevar specializing to a type compatible with the respective parameter of each overload. Most functions have a small number of overloads. But there are some examples of methods with 15-20 overloads (pydantic, numpy, our own auto-generated `__getitem__` for large tuple literals). For those cases, it is helpful to be more clever about how we construct the final result. Before, we would just step through the `Iterator` of elements and accumulate them into a result constraint set. That results in an `O(n)` number of calls to the underlying `and` or `or` operator — each of which might have to construct a large temporary BDD tree. AND and OR are both associative, so we can do better! We now invoke the operator in a "tree" shape (described in more detail in the doc comment). We still have to perform the same number of calls, but more of the calls operate on smaller BDDs, resulting in a much smaller amount of overall work.
1 parent 2838bc1 commit 2643fb0

1 file changed

Lines changed: 161 additions & 75 deletions

File tree

crates/ty_python_semantic/src/types/constraints.rs

Lines changed: 161 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ use crate::types::{
8484
BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints,
8585
UnionType, walk_bound_type_var_type,
8686
};
87-
use crate::{Db, FxIndexMap, FxOrderSet};
87+
use crate::{Db, FxIndexMap, FxIndexSet, FxOrderSet};
8888

8989
/// An extension trait for building constraint sets from [`Option`] values.
9090
pub(crate) trait OptionConstraintsExtension<T> {
@@ -147,27 +147,17 @@ where
147147
db: &'db dyn Db,
148148
mut f: impl FnMut(T) -> ConstraintSet<'db>,
149149
) -> ConstraintSet<'db> {
150-
let mut result = ConstraintSet::never();
151-
for child in self {
152-
if result.union(db, f(child)).is_always_satisfied(db) {
153-
return result;
154-
}
155-
}
156-
result
150+
let node = Node::distributed_or(db, self.map(|element| f(element).node));
151+
ConstraintSet { node }
157152
}
158153

159154
fn when_all<'db>(
160155
self,
161156
db: &'db dyn Db,
162157
mut f: impl FnMut(T) -> ConstraintSet<'db>,
163158
) -> ConstraintSet<'db> {
164-
let mut result = ConstraintSet::always();
165-
for child in self {
166-
if result.intersect(db, f(child)).is_never_satisfied(db) {
167-
return result;
168-
}
169-
}
170-
result
159+
let node = Node::distributed_and(db, self.map(|element| f(element).node));
160+
ConstraintSet { node }
171161
}
172162
}
173163

@@ -1174,6 +1164,104 @@ impl<'db> Node<'db> {
11741164
}
11751165
}
11761166

1167+
/// Combine an iterator of nodes into a single node using an associative operator.
1168+
///
1169+
/// Because the operator is associative, we don't have to combine the nodes left to right; we
1170+
/// can instead combine them in a "tree-like" way:
1171+
///
1172+
/// ```text
1173+
/// linear: (((((a ∨ b) ∨ c) ∨ d) ∨ e) ∨ f) ∨ g
1174+
/// tree: (((a ∨ b) ∨ (c ∨ d)) ∨ (e ∨ f)) ∨ g
1175+
/// ```
1176+
///
1177+
/// We have to invoke the operator the same number of times. But BDD operators are often much
1178+
/// cheaper when the operands are small, and with the tree shape, many more of the invocations
1179+
/// are performed on small BDDs.
1180+
///
1181+
/// You must also provide the "zero" and "one" units of the operator. The "zero" is the value
1182+
/// that has no effect (`0 ∨ a = a`). It is returned if the iterator is empty. The "one" is the
1183+
/// value that saturates (`1 ∨ a = 1`). We use this to short-circuit; if any element BDD or any
1184+
/// intermediate result evaluates to "one", we can return early.
1185+
fn tree_fold(
1186+
db: &'db dyn Db,
1187+
nodes: impl Iterator<Item = Self>,
1188+
zero: Self,
1189+
one: Self,
1190+
mut combine: impl FnMut(Self, &'db dyn Db, Self) -> Self,
1191+
) -> Self {
1192+
// To implement the "linear" shape described above, we could collect the iterator elements
1193+
// into a vector, and then use the fold at the bottom of this method to combine the
1194+
// elements using the operator.
1195+
//
1196+
// To implement the "tree" shape, we also maintain a "depth" for each element of the
1197+
// vector, which indicates how many times the operator has been applied to the element.
1198+
// As we collect elements into the vector, we keep it capped at a length `O(log n)` of the
1199+
// number of elements seen so far. To do that, whenever the last two elements of the vector
1200+
// have the same depth, we apply the operator once to combine those two elements, adding
1201+
// the result back to the vector with an incremented depth. (That might let us combine the
1202+
// result with the _next_ intermediate result in the vector, and so on.)
1203+
//
1204+
// Walking through the example above, our vector ends up looking like:
1205+
//
1206+
// a/0
1207+
// a/0 b/0 => ab/1
1208+
// ab/1 c/0
1209+
// ab/1 c/0 d/0 => ab/1 cd/1 => abcd/2
1210+
// abcd/2 e/0
1211+
// abcd/2 e/0 f/0 => abcd/2 ef/1
1212+
// abcd/2 ef/1 g/0
1213+
//
1214+
// We use a SmallVec for the accumulator so that we don't have to spill over to the heap
1215+
// until the iterator reaches 511 elements (the first count with more than eight 1 bits).
1216+
let mut accumulator: SmallVec<[(Node<'db>, u8); 8]> = SmallVec::default();
1217+
for node in nodes {
1218+
if node == one {
1219+
return one;
1220+
}
1221+
1222+
let (mut node, mut depth) = (node, 0);
1223+
while accumulator
1224+
.last()
1225+
.is_some_and(|(_, existing)| *existing == depth)
1226+
{
1227+
let (existing, _) = accumulator.pop().expect("accumulator should not be empty");
1228+
node = combine(existing, db, node);
1229+
if node == one {
1230+
return one;
1231+
}
1232+
depth += 1;
1233+
}
1234+
accumulator.push((node, depth));
1235+
}
1236+
1237+
// At this point, we've consumed all of the iterator. The length of the accumulator will be
1238+
// the same as the number of 1 bits in the length of the iterator. We do a final fold to
1239+
// produce the overall result.
1240+
accumulator
1241+
.into_iter()
1242+
.fold(zero, |result, (node, _)| combine(result, db, node))
1243+
}
1244+
1245+
fn distributed_or(db: &'db dyn Db, nodes: impl Iterator<Item = Node<'db>>) -> Self {
1246+
Self::tree_fold(
1247+
db,
1248+
nodes,
1249+
Node::AlwaysFalse,
1250+
Node::AlwaysTrue,
1251+
Self::or_with_offset,
1252+
)
1253+
}
1254+
1255+
fn distributed_and(db: &'db dyn Db, nodes: impl Iterator<Item = Node<'db>>) -> Self {
1256+
Self::tree_fold(
1257+
db,
1258+
nodes,
1259+
Node::AlwaysTrue,
1260+
Node::AlwaysFalse,
1261+
Self::and_with_offset,
1262+
)
1263+
}
1264+
11771265
/// Returns the `and` or intersection of two BDDs.
11781266
///
11791267
/// In the result, `self` will appear before `other` according to the `source_order` of the BDD
@@ -1785,56 +1873,66 @@ impl<'db> Node<'db> {
17851873
db: &'db dyn Db,
17861874
node: Node<'db>,
17871875
prefix: &'a dyn Display,
1876+
seen: RefCell<FxIndexSet<InteriorNode<'db>>>,
17881877
}
17891878

1790-
impl<'a, 'db> DisplayNode<'a, 'db> {
1791-
fn new(db: &'db dyn Db, node: Node<'db>, prefix: &'a dyn Display) -> Self {
1792-
Self { db, node, prefix }
1879+
fn format_node<'db>(
1880+
db: &'db dyn Db,
1881+
node: Node<'db>,
1882+
prefix: &dyn Display,
1883+
seen: &RefCell<FxIndexSet<InteriorNode<'db>>>,
1884+
f: &mut std::fmt::Formatter<'_>,
1885+
) -> std::fmt::Result {
1886+
match node {
1887+
Node::AlwaysTrue => write!(f, "always"),
1888+
Node::AlwaysFalse => write!(f, "never"),
1889+
Node::Interior(interior) => {
1890+
let (index, is_new) = seen.borrow_mut().insert_full(interior);
1891+
if !is_new {
1892+
return write!(f, "<{index}> SHARED");
1893+
}
1894+
write!(
1895+
f,
1896+
"<{index}> {} {}/{}",
1897+
interior.constraint(db).display(db),
1898+
interior.source_order(db),
1899+
interior.max_source_order(db),
1900+
)?;
1901+
// Calling display_graph recursively here causes rustc to claim that the
1902+
// expect(unused) up above is unfulfilled!
1903+
write!(f, "\n{prefix}┡━₁ ",)?;
1904+
format_node(
1905+
db,
1906+
interior.if_true(db),
1907+
&format_args!("{prefix}│ ",),
1908+
seen,
1909+
f,
1910+
)?;
1911+
write!(f, "\n{prefix}└─₀ ",)?;
1912+
format_node(
1913+
db,
1914+
interior.if_false(db),
1915+
&format_args!("{prefix} ",),
1916+
seen,
1917+
f,
1918+
)?;
1919+
Ok(())
1920+
}
17931921
}
17941922
}
17951923

17961924
impl Display for DisplayNode<'_, '_> {
17971925
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1798-
match self.node {
1799-
Node::AlwaysTrue => write!(f, "always"),
1800-
Node::AlwaysFalse => write!(f, "never"),
1801-
Node::Interior(interior) => {
1802-
write!(
1803-
f,
1804-
"{} {}/{}",
1805-
interior.constraint(self.db).display(self.db),
1806-
interior.source_order(self.db),
1807-
interior.max_source_order(self.db),
1808-
)?;
1809-
// Calling display_graph recursively here causes rustc to claim that the
1810-
// expect(unused) up above is unfulfilled!
1811-
write!(
1812-
f,
1813-
"\n{}┡━₁ {}",
1814-
self.prefix,
1815-
DisplayNode::new(
1816-
self.db,
1817-
interior.if_true(self.db),
1818-
&format_args!("{}│ ", self.prefix)
1819-
),
1820-
)?;
1821-
write!(
1822-
f,
1823-
"\n{}└─₀ {}",
1824-
self.prefix,
1825-
DisplayNode::new(
1826-
self.db,
1827-
interior.if_false(self.db),
1828-
&format_args!("{} ", self.prefix)
1829-
),
1830-
)?;
1831-
Ok(())
1832-
}
1833-
}
1926+
format_node(self.db, self.node, self.prefix, &self.seen, f)
18341927
}
18351928
}
18361929

1837-
DisplayNode::new(db, self, prefix)
1930+
DisplayNode {
1931+
db,
1932+
node: self,
1933+
prefix,
1934+
seen: RefCell::default(),
1935+
}
18381936
}
18391937
}
18401938

@@ -4002,30 +4100,18 @@ mod tests {
40024100
#[test]
40034101
fn test_display_graph_output() {
40044102
let expected = indoc! {r#"
4005-
(U = bool) 2/4
4006-
┡━₁ (U = str) 1/4
4007-
│ ┡━₁ (T = bool) 4/4
4008-
│ │ ┡━₁ (T = str) 3/3
4103+
<0> (U = bool) 2/4
4104+
┡━₁ <1> (U = str) 1/4
4105+
│ ┡━₁ <2> (T = bool) 4/4
4106+
│ │ ┡━₁ <3> (T = str) 3/3
40094107
│ │ │ ┡━₁ always
40104108
│ │ │ └─₀ always
4011-
│ │ └─₀ (T = str) 3/3
4109+
│ │ └─₀ <4> (T = str) 3/3
40124110
│ │ ┡━₁ always
40134111
│ │ └─₀ never
4014-
│ └─₀ (T = bool) 4/4
4015-
│ ┡━₁ (T = str) 3/3
4016-
│ │ ┡━₁ always
4017-
│ │ └─₀ always
4018-
│ └─₀ (T = str) 3/3
4019-
│ ┡━₁ always
4020-
│ └─₀ never
4021-
└─₀ (U = str) 1/4
4022-
┡━₁ (T = bool) 4/4
4023-
│ ┡━₁ (T = str) 3/3
4024-
│ │ ┡━₁ always
4025-
│ │ └─₀ always
4026-
│ └─₀ (T = str) 3/3
4027-
│ ┡━₁ always
4028-
│ └─₀ never
4112+
│ └─₀ <2> SHARED
4113+
└─₀ <5> (U = str) 1/4
4114+
┡━₁ <2> SHARED
40294115
└─₀ never
40304116
"#}
40314117
.trim_end();

0 commit comments

Comments
 (0)