@@ -1428,6 +1428,7 @@ def replace_pattern(new_pattern):
14281428 self .remove_nodes ,
14291429 self .graph_pre_visitor ,
14301430 self .graph_post_visitor ,
1431+ self .as_function ,
14311432 )
14321433
14331434 return [replace_pattern (p ) for p in self ._target_pattern .commute ()]
@@ -1509,21 +1510,23 @@ class RewriteRuleClassBase:
15091510 @classmethod
15101511 def rule (cls , * args , ** kwargs ):
15111512 instance = cls (* args , ** kwargs )
1512- setup = instance .setup if hasattr (instance , "setup" ) else None
1513- cleanup = instance .cleanup if hasattr (instance , "cleanup" ) else None
15141513 return RewriteRule (
15151514 instance .pattern ,
15161515 instance .rewrite ,
15171516 instance .check ,
15181517 name = instance .name ,
15191518 remove_nodes = instance .remove_nodes ,
1520- graph_pre_visitor = setup ,
1521- graph_post_visitor = cleanup ,
1519+ graph_pre_visitor = instance .setup ,
1520+ graph_post_visitor = instance .cleanup ,
1521+ as_function = instance .as_function ,
15221522 )
15231523
1524- def __init__ (self , name : str | None = None , remove_nodes : bool = True ) -> None :
1524+ def __init__ (
1525+ self , name : str | None = None , remove_nodes : bool = True , as_function : bool = False
1526+ ) -> None :
15251527 self .name = name or self .__class__ .__name__
15261528 self .remove_nodes = remove_nodes
1529+ self .as_function = as_function
15271530
15281531 def pattern (self , op , * args , ** kwargs ):
15291532 raise NotImplementedError ("Method 'pattern' must be implemented by derived class." )
@@ -1535,30 +1538,52 @@ def check(self, op, *args, **kwargs):
15351538 def rewrite (self , op , * args , ** kwargs ):
15361539 raise NotImplementedError ("Method 'rewrite' must be implemented by derived class." )
15371540
1541+ def setup (self ):
1542+ # Optional setup function that can be overridden by derived classes. Used to do
1543+ # per model/function initialization.
1544+ pass
1545+
1546+ def cleanup (self ):
1547+ # Optional cleanup function that can be overridden by derived classes. Used to do
1548+ # per model/function cleanup.
1549+ pass
1550+
15381551
15391552def _copy_for_function (
15401553 inputs : Sequence [ir .Value | None ], nodes : Sequence [ir .Node ], outputs : Sequence [ir .Value ]
15411554):
15421555 """Utility function to extract a subgraph out as a function."""
15431556 value_map : dict [ir .Value , ir .Value ] = {}
15441557 function_inputs : list [ir .Value ] = []
1558+ constant_nodes : list [ir .Node ] = []
15451559 for input in inputs :
15461560 # Create a function input (formal-parameter value) to represent this value:
1547- if input is None :
1548- raise NotImplementedError ("None inputs not supported." )
1549- new_value = ir .Value (
1550- name = input .name ,
1551- shape = input .shape ,
1552- type = input .type ,
1553- doc_string = input .doc_string ,
1561+ new_value = (
1562+ ir .Value (
1563+ name = input .name ,
1564+ shape = input .shape ,
1565+ type = input .type ,
1566+ doc_string = input .doc_string ,
1567+ )
1568+ if input
1569+ else ir .Value () # dummy parameter for a None input
15541570 )
1555- value_map [input ] = new_value
1571+ if input is not None :
1572+ value_map [input ] = new_value
15561573 function_inputs .append (new_value )
15571574
15581575 def copy_value (value : ir .Value | None ) -> ir .Value | None :
15591576 if value is None :
15601577 return None
15611578 if value not in value_map :
1579+ const_value = value .const_value
1580+ if const_value is not None :
1581+ # create a Constant node to represent the value
1582+ value_attr = ir .AttrTensor ("value" , const_value )
1583+ const_node = ir .Node ("" , "Constant" , [], [value_attr ])
1584+ constant_nodes .append (const_node )
1585+ value_map [value ] = result = const_node .outputs [0 ]
1586+ return result
15621587 raise ValueError (f"Value { value } not found in value_map." )
15631588 return value_map [value ]
15641589
@@ -1598,7 +1623,7 @@ def copy_node(node: ir.Node) -> ir.Node:
15981623
15991624 function_nodes = [copy_node (node ) for node in nodes ]
16001625 function_outputs = [copy_value (v ) for v in outputs ]
1601- return (function_inputs , function_nodes , function_outputs )
1626+ return (function_inputs , constant_nodes + function_nodes , function_outputs )
16021627
16031628
16041629def _get_new_overload (model : ir .Model , domain : str , name : str ) -> str :
0 commit comments