2929#
3030# This produces cos/sin values in a form that can be used by ORT's custom ops.
3131
32- # TODO: To apply the pattern-rewrite, we need to know the maximum position id.
33- # Need to find a way to get this information from the model or its config.
34-
3532
3633class CosSinCacheFusion (pattern .RewriteRuleClassBase ):
3734 def __init__ (
3835 self ,
3936 name : str ,
40- max_pos_id : int ,
4137 * ,
4238 cast : bool = False ,
4339 reshape : bool = False ,
@@ -47,13 +43,66 @@ def __init__(
4743 # matched nodes as part of the rewrite-step. We apply a separate final
4844 # pass to remove unused nodes.
4945 super ().__init__ (name , remove_nodes = False )
50- self ._max_pos_id = max_pos_id
46+ # TODO: Determine what should be the default max_pos_id value
47+ self ._max_pos_id = None
5148 # map from inv_freq to (cos, sin) values for transformed graph
5249 self ._inv_freq_cos_sin_cache : dict [ir .Value , tuple [ir .Value , ir .Value ]] = {}
5350 self ._reshape = reshape
5451 self ._cast = cast
5552 self ._const_freqs = const_freqs
5653
54+ @property
55+ def max_pos_id (self ) -> int | None :
56+ return self ._max_pos_id
57+
58+ @max_pos_id .setter
59+ def max_pos_id (self , max_pos_id : int ):
60+ self ._max_pos_id = max_pos_id # type: ignore[assignment]
61+
62+ def _compute_const_freqs (self , op , freqs ):
63+ """Compute cos/sin values when frequencies are constant."""
64+ angles = freqs .const_value .numpy ()
65+ cos_value = np .cos (angles )
66+ sin_value = np .sin (angles )
67+ cos_2d = op .Constant (value = ir .tensor (cos_value ))
68+ sin_2d = op .Constant (value = ir .tensor (sin_value ))
69+ return cos_2d , sin_2d
70+
71+ def _compute_dynamic_freqs (self , op , inv_freq , position_ids , dtype ):
72+ """Compute cos/sin values dynamically based on inv_freq and position_ids."""
73+ if self ._max_pos_id is not None :
74+ # Use max_pos_id from the model metadata
75+ max_pos_id = self ._max_pos_id
76+ elif position_ids .const_value is not None :
77+ # Calculate max_pos_id from the position_ids tensor
78+ max_pos_id = int (np .max (position_ids .const_value .numpy ()))
79+ else :
80+ # Dynamically compute max_pos_id from position_ids using ONNX ops
81+ inv_freq = op .Reshape (inv_freq , op .Constant (value_ints = [1 , - 1 ]))
82+ max_pos_id = op .ReduceMax (position_ids , keepdims = 0 )
83+ max_pos_id = op .Add (max_pos_id , op .Constant (value_int = 1 ))
84+ pos_id_range = op .Range (
85+ op .Constant (value_int = 0 ),
86+ max_pos_id ,
87+ op .Constant (value_int = 1 ),
88+ )
89+ pos_id_range = op .Reshape (pos_id_range , op .Constant (value_ints = [- 1 , 1 ]))
90+ pos_id_range = op .Cast (pos_id_range , to = ir .DataType .FLOAT )
91+ # Compute angles and cos/sin values
92+ angles = op .MatMul (pos_id_range , inv_freq )
93+ cos_2d = op .Cos (angles )
94+ sin_2d = op .Sin (angles )
95+ return cos_2d , sin_2d
96+
97+ # If we do not compute max_pos_id using ONNX ops, use inv_freq and position_ids
98+ # to compute angles and cos/sin values
99+ # Note: The one is added to max_pos_id as position_ids are 0-indexed
100+ # and the range of position ids should be [0, max_pos_id], max_pos_id inclusive.
101+ inv_freq_values = inv_freq .const_value .numpy ().reshape (1 , - 1 )
102+ pos_id_range = np .arange (max_pos_id + 1 , dtype = np .float32 ).reshape (- 1 , 1 )
103+ angles = np .matmul (pos_id_range , inv_freq_values )
104+ return self ._compute_const_freqs (op , angles )
105+
57106 def cleanup (self ):
58107 self ._inv_freq_cos_sin_cache .clear ()
59108
@@ -128,16 +177,11 @@ def rewrite(
128177 if inv_freq in self ._inv_freq_cos_sin_cache :
129178 cos_2d , sin_2d = self ._inv_freq_cos_sin_cache [inv_freq ]
130179 else :
180+ # Compute cos/sin values based on whether frequencies are constant
131181 if self ._const_freqs :
132- angles = freqs . const_value . numpy ( )
182+ cos_2d , sin_2d = self . _compute_const_freqs ( op , freqs )
133183 else :
134- inv_freq_values = inv_freq .const_value .numpy ().reshape (1 , - 1 )
135- pos_id_range = np .arange (self ._max_pos_id , dtype = np .float32 ).reshape (- 1 , 1 )
136- angles = np .matmul (pos_id_range , inv_freq_values )
137- cos_value = np .cos (angles )
138- sin_value = np .sin (angles )
139- cos_2d = op .Constant (value = ir .tensor (cos_value ))
140- sin_2d = op .Constant (value = ir .tensor (sin_value ))
184+ cos_2d , sin_2d = self ._compute_dynamic_freqs (op , inv_freq , position_ids , dtype )
141185 if self ._cast :
142186 cos_2d = op .Cast (cos_2d , to = dtype )
143187 sin_2d = op .Cast (sin_2d , to = dtype )
@@ -157,13 +201,11 @@ def rewrite(
157201
158202
159203_cast_const_freqs = CosSinCacheFusion .rule (
160- "CosSinCache_cast_const_freqs" , 2048 , cast = True , const_freqs = True
161- )
162- _cast = CosSinCacheFusion .rule ("CosSinCache_cast" , 2048 , cast = True , const_freqs = False )
163- _const_freqs = CosSinCacheFusion .rule (
164- "CosSinCache_const_freqs" , 2048 , cast = False , const_freqs = True
204+ "CosSinCache_cast_const_freqs" , cast = True , const_freqs = True
165205)
166- _basic = CosSinCacheFusion .rule ("CosSinCache" , 2048 , cast = False )
206+ _cast = CosSinCacheFusion .rule ("CosSinCache_cast" , cast = True , const_freqs = False )
207+ _const_freqs = CosSinCacheFusion .rule ("CosSinCache_const_freqs" , cast = False , const_freqs = True )
208+ _basic = CosSinCacheFusion .rule ("CosSinCache" , cast = False )
167209
168210cos_sin_cache_rules = pattern .RewriteRuleSet ([_cast , _cast_const_freqs , _const_freqs , _basic ])
169211
0 commit comments