2020
2121__all__ = [
2222 "PassBase" ,
23+ "Sequential" ,
24+ "InPlacePass" ,
25+ "FunctionalPass" ,
2326 "PassManager" ,
2427 "PassResult" ,
2528 # Errors
@@ -68,14 +71,72 @@ class PassResult:
6871class PassBase (abc .ABC ):
6972 """Base class for all passes.
7073
71- Class attributes:
72- in_place: Whether the pass modifies the model in place.
74+
75+ ``in_place`` and ``changes_input`` properties and what they mean:
76+
77+ +------------+------------------+----------------------------+
78+ | | changes_inputs | not changes_inputs |
79+ +------------+------------------+----------------------------+
80+ | in_place | in place | Side-effect-only pass |
81+ +------------+------------------+----------------------------+
82+ | not | destructive | functional |
83+ | in_place | | |
84+ +------------+------------------+----------------------------+
7385 """
7486
75- in_place : bool = True
87+ @property
88+ @abc .abstractmethod
89+ def in_place (self ) -> bool :
90+ """Whether the pass modifies the model in place and returns it.
91+
92+ If True, the pass will return the same model object that was passed in.
93+ If False, the pass will return a new model object.
94+ """
95+ raise NotImplementedError
96+
97+ @property
98+ @abc .abstractmethod
99+ def changes_input (self ) -> bool :
100+ """Whether the pass modifies input model."""
101+ raise NotImplementedError
102+
103+ @property
104+ def destructive (self ) -> bool :
105+ """Whether the pass will destroy the input model when ``in_place=False``.
106+
107+ A pass is destructive if it is not in place and it modifies the input model.
108+ """
109+ return not self .in_place and self .changes_input
76110
77111 def __call__ (self , model : ir .Model ) -> PassResult :
78- return self .call (model )
112+ # Check preconditions
113+ try :
114+ self .requires (model )
115+ except PreconditionError :
116+ raise
117+ except Exception as e :
118+ raise PreconditionError (
119+ f"Pre-condition for pass '{ self .__class__ .__name__ } ' failed"
120+ ) from e
121+
122+ result = self .call (model )
123+
124+ # Check postconditions
125+ try :
126+ self .ensures (model )
127+ except PostconditionError :
128+ raise
129+ except Exception as e :
130+ raise PostconditionError (
131+ f"Post-condition for pass '{ self .__class__ .__name__ } ' failed"
132+ ) from e
133+
134+ if not isinstance (result , PassResult ):
135+ raise TypeError (
136+ f"The result of the pass '{ self .__class__ .__name__ } ' should be type PassResult. "
137+ "Please create one with ir.passes.PassResult()."
138+ )
139+ return result
79140
80141 @abc .abstractmethod
81142 def call (self , model : ir .Model ) -> PassResult :
@@ -97,76 +158,105 @@ def ensures(self, model: ir.Model) -> None:
97158 del model # Unused
98159
99160
100- class PassManager :
161+ class InPlacePass (PassBase ):
162+ """A pass that modifies the input model in place and returns it."""
163+
164+ @property
165+ def in_place (self ) -> bool :
166+ return True
167+
168+ @property
169+ def changes_input (self ) -> bool :
170+ return True
171+
172+
173+ class FunctionalPass (PassBase ):
174+ """A pass that returns a new model but does not modify the input model."""
175+
176+ @property
177+ def in_place (self ) -> bool :
178+ return False
179+
180+ @property
181+ def changes_input (self ) -> bool :
182+ return False
183+
184+
185+ class Sequential (PassBase ):
186+ """Run a sequence of passes in order."""
187+
188+ def __init__ (self , * passes : PassBase ):
189+ if not passes :
190+ raise ValueError ("Sequential must take at least one pass" )
191+ self .passes = passes
192+ self ._in_place = all (pass_ .in_place for pass_ in passes )
193+ # The reason changes_inputs is decided by the first pass is that if the first pass is either in-place,
194+ # or if it is not designed to be in-place but somehow changes the input (destructive),
195+ # this pass sequence will change inputs.
196+ self ._changes_input = self .passes [0 ].changes_input or self .passes [0 ].in_place
197+
198+ @property
199+ def in_place (self ) -> bool :
200+ return self ._in_place
201+
202+ @property
203+ def changes_input (self ) -> bool :
204+ return self ._changes_input
205+
206+ def call (self , model : ir .Model ) -> PassResult :
207+ modified = False
208+ for i , pass_ in enumerate (self .passes ):
209+ logger .debug ("Running the %s-th pass '%s'" , i , pass_ )
210+ try :
211+ pass_result = pass_ (model )
212+ except Exception as e :
213+ prev_pass_names = [str (p ) for p in self .passes [:i ]]
214+ raise PassError (
215+ f"An error occurred when running the '{ pass_ } ' pass after the "
216+ f"following passes: { prev_pass_names } "
217+ ) from e
218+
219+ model = pass_result .model
220+ modified = modified or pass_result .modified
221+
222+ return PassResult (model , modified )
223+
224+
225+ class PassManager (Sequential ):
101226 """Pass manager for the IR.
102227
103- The PassManager is a callable that runs a sequence of passes on a model.
228+ The PassManager is a Pass that runs a sequence of passes on a model.
104229
105230 Attributes:
106231 passes: The passes to run.
107- check_invariants: Whether to check invariants before and after each pass.
108232 steps: The number of times to run the passes.
233+ early_stop: Whether to stop running the passes if the graph stops changing.
109234 """
110235
111236 def __init__ (
112237 self ,
113238 passes : Sequence [PassBase ],
114- check_invariants : bool = False ,
115239 steps : int = 1 ,
240+ early_stop : bool = True ,
116241 ):
117242 # TODO(justinchuby): Implement constraints
118- self .passes = list (passes )
119- self .check_invariants = check_invariants
243+ super ().__init__ (* passes )
120244 self .steps = steps
245+ self .early_stop = early_stop
121246
122- def __call__ (self , model : ir .Model ) -> PassResult :
247+ def call (self , model : ir .Model ) -> PassResult :
123248 """Run the set of passes `steps` number of times or until the graph stops changing."""
124249 overall_modified = False
125250 for step in range (self .steps ):
126- step_result = self ._run_one_step (model , step )
251+ try :
252+ step_result = super ().__call__ (model )
253+ except Exception as e :
254+ raise PassError (f"An error occurred at step { step } " ) from e
127255 model = step_result .model
128256 modified = step_result .modified
129257 overall_modified = overall_modified or modified
130258 # If the graph no longer changes, then we can stop running these passes
131- if not modified :
259+ if not modified and self . early_stop :
132260 logger .info ("PassManager: No more graph changes detected after step %s" , step )
133261 break
134262 return PassResult (model , overall_modified )
135-
136- def _run_one_step (self , model : ir .Model , step : int ) -> PassResult :
137- modified = False
138- for i , pass_ in enumerate (self .passes ):
139- logger .debug ("Running the %s-th pass '%s', (step %s)" , i , pass_ , step )
140-
141- # 1. Check preconditions
142- if self .check_invariants :
143- try :
144- pass_ .requires (model )
145- except Exception as e :
146- raise PreconditionError (f"Pre-condition failed for { pass_ } " ) from e
147-
148- # 2. Run the pass
149- try :
150- pass_result = pass_ (model )
151- except Exception as e :
152- prev_pass_names = [str (p ) for p in self .passes [:i ]]
153- raise PassError (
154- f"An error occurred when running the '{ pass_ } ' pass after the "
155- f"following passes: { prev_pass_names } during step { step } "
156- ) from e
157- if not isinstance (pass_result , PassResult ):
158- raise TypeError (
159- f"The result of the pass { pass_ } should be type PassResult."
160- "Please create one with ir.passes.PassResult()."
161- )
162-
163- model = pass_result .model
164- modified = modified or pass_result .modified
165-
166- # 3. Check postconditions
167- if self .check_invariants :
168- try :
169- pass_ .ensures (model )
170- except Exception as e :
171- raise PostconditionError (f"Post-condition failed for { pass_ } " ) from e
172- return PassResult (model , modified )
0 commit comments