@@ -84,7 +84,7 @@ use crate::types::{
8484 BoundTypeVarIdentity , BoundTypeVarInstance , IntersectionType , Type , TypeVarBoundOrConstraints ,
8585 UnionType , walk_bound_type_var_type,
8686} ;
87- use crate :: { Db , FxIndexMap , FxOrderSet } ;
87+ use crate :: { Db , FxIndexMap , FxIndexSet , FxOrderSet } ;
8888
8989/// An extension trait for building constraint sets from [`Option`] values.
9090pub ( crate ) trait OptionConstraintsExtension < T > {
@@ -147,27 +147,17 @@ where
147147 db : & ' db dyn Db ,
148148 mut f : impl FnMut ( T ) -> ConstraintSet < ' db > ,
149149 ) -> ConstraintSet < ' db > {
150- let mut result = ConstraintSet :: never ( ) ;
151- for child in self {
152- if result. union ( db, f ( child) ) . is_always_satisfied ( db) {
153- return result;
154- }
155- }
156- result
150+ let node = Node :: distributed_or ( db, self . map ( |element| f ( element) . node ) ) ;
151+ ConstraintSet { node }
157152 }
158153
159154 fn when_all < ' db > (
160155 self ,
161156 db : & ' db dyn Db ,
162157 mut f : impl FnMut ( T ) -> ConstraintSet < ' db > ,
163158 ) -> ConstraintSet < ' db > {
164- let mut result = ConstraintSet :: always ( ) ;
165- for child in self {
166- if result. intersect ( db, f ( child) ) . is_never_satisfied ( db) {
167- return result;
168- }
169- }
170- result
159+ let node = Node :: distributed_and ( db, self . map ( |element| f ( element) . node ) ) ;
160+ ConstraintSet { node }
171161 }
172162}
173163
@@ -1174,6 +1164,104 @@ impl<'db> Node<'db> {
11741164 }
11751165 }
11761166
1167+ /// Combine an iterator of nodes into a single node using an associative operator.
1168+ ///
1169+ /// Because the operator is associative, we don't have to combine the nodes left to right; we
1170+ /// can instead combine them in a "tree-like" way:
1171+ ///
1172+ /// ```text
1173+ /// linear: (((((a ∨ b) ∨ c) ∨ d) ∨ e) ∨ f) ∨ g
1174+ /// tree: ((a ∨ b) ∨ (c ∨ d)) ∨ ((e ∨ f) ∨ g)
1175+ /// ```
1176+ ///
1177+ /// We have to invoke the operator the same number of times. But BDD operators are often much
1178+ /// cheaper when the operands are small, and with the tree shape, many more of the invocations
1179+ /// are performed on small BDDs.
1180+ ///
1181+ /// You must also provide the "zero" and "one" units of the operator. The "zero" is the value
1182+ /// that has no effect (`0 ∨ a = a`). It is returned if the iterator is empty. The "one" is the
1183+ /// value that saturates (`1 ∨ a = 1`). We use this to short-circuit; if any element BDD or any
1184+ /// intermediate result evaluates to "one", we can return early.
1185+ fn tree_fold (
1186+ db : & ' db dyn Db ,
1187+ nodes : impl Iterator < Item = Self > ,
1188+ zero : Self ,
1189+ one : Self ,
1190+ mut combine : impl FnMut ( Self , & ' db dyn Db , Self ) -> Self ,
1191+ ) -> Self {
1192+ // To implement the "linear" shape described above, we could collect the iterator elements
1193+ // into a vector, and then use the fold at the bottom of this method to combine the
1194+ // elements using the operator.
1195+ //
1196+ // To implement the "tree" shape, we also maintain a "depth" for each element of the
1197+ // vector, which indicates how many times the operator has been applied to the element.
1198+ // As we collect elements into the vector, we keep it capped at a length `O(log n)` of the
1199+ // number of elements seen so far. To do that, whenever the last two elements of the vector
1200+ // have the same depth, we apply the operator once to combine those two elements, adding
1201+ // the result back to the vector with an incremented depth. (That might let us combine the
1202+ // result with the _next_ intermediate result in the vector, and so on.)
1203+ //
1204+ // Walking through the example above, our vector ends up looking like:
1205+ //
1206+ // a/0
1207+ // a/0 b/0 => ab/1
1208+ // ab/1 c/0
1209+ // ab/1 c/0 d/0 => ab/1 cd/1 => abcd/2
1210+ // abcd/2 e/0
1211+ // abcd/2 e/0 f/0 => abcd/2 ef/1
1212+ // abcd/2 ef/1 g/0
1213+ //
1214+ // We use a SmallVec for the accumulator so that we don't have to spill over to the heap
1215+ // until the iterator passes 256 elements.
1216+ let mut accumulator: SmallVec < [ ( Node < ' db > , u8 ) ; 8 ] > = SmallVec :: default ( ) ;
1217+ for node in nodes {
1218+ if node == one {
1219+ return one;
1220+ }
1221+
1222+ let ( mut node, mut depth) = ( node, 0 ) ;
1223+ while accumulator
1224+ . last ( )
1225+ . is_some_and ( |( _, existing) | * existing == depth)
1226+ {
1227+ let ( existing, _) = accumulator. pop ( ) . expect ( "accumulator should not be empty" ) ;
1228+ node = combine ( existing, db, node) ;
1229+ if node == one {
1230+ return one;
1231+ }
1232+ depth += 1 ;
1233+ }
1234+ accumulator. push ( ( node, depth) ) ;
1235+ }
1236+
1237+ // At this point, we've consumed all of the iterator. The length of the accumulator will be
1238+ // the same as the number of 1 bits in the length of the iterator. We do a final fold to
1239+ // produce the overall result.
1240+ accumulator
1241+ . into_iter ( )
1242+ . fold ( zero, |result, ( node, _) | combine ( result, db, node) )
1243+ }
1244+
1245+ fn distributed_or ( db : & ' db dyn Db , nodes : impl Iterator < Item = Node < ' db > > ) -> Self {
1246+ Self :: tree_fold (
1247+ db,
1248+ nodes,
1249+ Node :: AlwaysFalse ,
1250+ Node :: AlwaysTrue ,
1251+ Self :: or_with_offset,
1252+ )
1253+ }
1254+
1255+ fn distributed_and ( db : & ' db dyn Db , nodes : impl Iterator < Item = Node < ' db > > ) -> Self {
1256+ Self :: tree_fold (
1257+ db,
1258+ nodes,
1259+ Node :: AlwaysTrue ,
1260+ Node :: AlwaysFalse ,
1261+ Self :: and_with_offset,
1262+ )
1263+ }
1264+
11771265 /// Returns the `and` or intersection of two BDDs.
11781266 ///
11791267 /// In the result, `self` will appear before `other` according to the `source_order` of the BDD
@@ -1785,56 +1873,66 @@ impl<'db> Node<'db> {
17851873 db : & ' db dyn Db ,
17861874 node : Node < ' db > ,
17871875 prefix : & ' a dyn Display ,
1876+ seen : RefCell < FxIndexSet < InteriorNode < ' db > > > ,
17881877 }
17891878
1790- impl < ' a , ' db > DisplayNode < ' a , ' db > {
1791- fn new ( db : & ' db dyn Db , node : Node < ' db > , prefix : & ' a dyn Display ) -> Self {
1792- Self { db, node, prefix }
1879+ fn format_node < ' db > (
1880+ db : & ' db dyn Db ,
1881+ node : Node < ' db > ,
1882+ prefix : & dyn Display ,
1883+ seen : & RefCell < FxIndexSet < InteriorNode < ' db > > > ,
1884+ f : & mut std:: fmt:: Formatter < ' _ > ,
1885+ ) -> std:: fmt:: Result {
1886+ match node {
1887+ Node :: AlwaysTrue => write ! ( f, "always" ) ,
1888+ Node :: AlwaysFalse => write ! ( f, "never" ) ,
1889+ Node :: Interior ( interior) => {
1890+ let ( index, is_new) = seen. borrow_mut ( ) . insert_full ( interior) ;
1891+ if !is_new {
1892+ return write ! ( f, "<{index}> SHARED" ) ;
1893+ }
1894+ write ! (
1895+ f,
1896+ "<{index}> {} {}/{}" ,
1897+ interior. constraint( db) . display( db) ,
1898+ interior. source_order( db) ,
1899+ interior. max_source_order( db) ,
1900+ ) ?;
1901+ // Calling display_graph recursively here causes rustc to claim that the
1902+ // expect(unused) up above is unfulfilled!
1903+ write ! ( f, "\n {prefix}┡━₁ " , ) ?;
1904+ format_node (
1905+ db,
1906+ interior. if_true ( db) ,
1907+ & format_args ! ( "{prefix}│ " , ) ,
1908+ seen,
1909+ f,
1910+ ) ?;
1911+ write ! ( f, "\n {prefix}└─₀ " , ) ?;
1912+ format_node (
1913+ db,
1914+ interior. if_false ( db) ,
1915+ & format_args ! ( "{prefix} " , ) ,
1916+ seen,
1917+ f,
1918+ ) ?;
1919+ Ok ( ( ) )
1920+ }
17931921 }
17941922 }
17951923
17961924 impl Display for DisplayNode < ' _ , ' _ > {
17971925 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
1798- match self . node {
1799- Node :: AlwaysTrue => write ! ( f, "always" ) ,
1800- Node :: AlwaysFalse => write ! ( f, "never" ) ,
1801- Node :: Interior ( interior) => {
1802- write ! (
1803- f,
1804- "{} {}/{}" ,
1805- interior. constraint( self . db) . display( self . db) ,
1806- interior. source_order( self . db) ,
1807- interior. max_source_order( self . db) ,
1808- ) ?;
1809- // Calling display_graph recursively here causes rustc to claim that the
1810- // expect(unused) up above is unfulfilled!
1811- write ! (
1812- f,
1813- "\n {}┡━₁ {}" ,
1814- self . prefix,
1815- DisplayNode :: new(
1816- self . db,
1817- interior. if_true( self . db) ,
1818- & format_args!( "{}│ " , self . prefix)
1819- ) ,
1820- ) ?;
1821- write ! (
1822- f,
1823- "\n {}└─₀ {}" ,
1824- self . prefix,
1825- DisplayNode :: new(
1826- self . db,
1827- interior. if_false( self . db) ,
1828- & format_args!( "{} " , self . prefix)
1829- ) ,
1830- ) ?;
1831- Ok ( ( ) )
1832- }
1833- }
1926+ format_node ( self . db , self . node , self . prefix , & self . seen , f)
18341927 }
18351928 }
18361929
1837- DisplayNode :: new ( db, self , prefix)
1930+ DisplayNode {
1931+ db,
1932+ node : self ,
1933+ prefix,
1934+ seen : RefCell :: default ( ) ,
1935+ }
18381936 }
18391937}
18401938
@@ -4002,30 +4100,18 @@ mod tests {
40024100 #[ test]
40034101 fn test_display_graph_output ( ) {
40044102 let expected = indoc ! { r#"
4005- (U = bool) 2/4
4006- ┡━₁ (U = str) 1/4
4007- │ ┡━₁ (T = bool) 4/4
4008- │ │ ┡━₁ (T = str) 3/3
4103+ <0> (U = bool) 2/4
4104+ ┡━₁ <1> (U = str) 1/4
4105+ │ ┡━₁ <2> (T = bool) 4/4
4106+ │ │ ┡━₁ <3> (T = str) 3/3
40094107 │ │ │ ┡━₁ always
40104108 │ │ │ └─₀ always
4011- │ │ └─₀ (T = str) 3/3
4109+ │ │ └─₀ <4> (T = str) 3/3
40124110 │ │ ┡━₁ always
40134111 │ │ └─₀ never
4014- │ └─₀ (T = bool) 4/4
4015- │ ┡━₁ (T = str) 3/3
4016- │ │ ┡━₁ always
4017- │ │ └─₀ always
4018- │ └─₀ (T = str) 3/3
4019- │ ┡━₁ always
4020- │ └─₀ never
4021- └─₀ (U = str) 1/4
4022- ┡━₁ (T = bool) 4/4
4023- │ ┡━₁ (T = str) 3/3
4024- │ │ ┡━₁ always
4025- │ │ └─₀ always
4026- │ └─₀ (T = str) 3/3
4027- │ ┡━₁ always
4028- │ └─₀ never
4112+ │ └─₀ <2> SHARED
4113+ └─₀ <5> (U = str) 1/4
4114+ ┡━₁ <2> SHARED
40294115 └─₀ never
40304116 "# }
40314117 . trim_end ( ) ;
0 commit comments