@@ -208,114 +208,40 @@ def test_version_convert_gridsample_cubic(self):
208208 self .assertEqual (model .graph .node (4 ).version , 20 )
209209 self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "cubic" )
210210
211- def test_version_convert_function_nodes (self ):
212- """Test that version converter processes nodes inside model functions."""
211+ def test_version_convert_inline (self ):
213212 model = ir .from_onnx_text (
214213 """
215- <ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
216- agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
217- {
218- output = pkg.custom.dft_func (input_x)
219- }
220-
221- <domain: "pkg.custom", opset_import: [ "" : 18]>
222- dft_func (x) => (result) {
223- shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
224- reshape_x = Reshape (x, shape_a)
225- dft = DFT <axis = 2, onesided = 1> (reshape_x)
226- shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
227- result = Reshape (dft, shape_c)
228- }
229- """
230- )
231- # Verify the function exists with correct initial state
232- self .assertEqual (len (model .functions ), 1 )
233- func = model .functions [("pkg.custom" , "dft_func" , "" )]
234- self .assertEqual (len (func ), 5 ) # 5 nodes in the function
235-
236- target_version = 20
237- version_converter .convert_version (model , target_version = target_version )
238- self .assertEqual (model .opset_imports ["" ], target_version )
239-
240- # Verify that nodes inside the function were version-converted
241- func = model .functions [("pkg.custom" , "dft_func" , "" )]
242- self .assertEqual (func [0 ].op_type , "Constant" )
243- self .assertEqual (func [0 ].version , 20 )
244- self .assertEqual (func [1 ].op_type , "Reshape" )
245- self .assertEqual (func [1 ].version , 20 )
246- # After DFT adapter, a new Constant node is inserted for dft_length
247- self .assertEqual (func [2 ].op_type , "Constant" )
248- self .assertEqual (func [2 ].version , 20 )
249- self .assertEqual (func [3 ].op_type , "DFT" )
250- self .assertEqual (func [3 ].version , 20 )
251- self .assertEqual (len (func [3 ].inputs ), 3 ) # DFT 19->20 adds dft_length input
252-
253- def test_version_convert_function_with_control_flow_subgraph (self ):
254- """Test that version converter processes subgraphs inside control flow nodes in functions."""
255- model = ir .from_onnx_text (
256- """
257- <ir_version: 8, opset_import: [ "" : 18, "pkg.custom": 1]>
258- agraph (float[4, 512, 512] input_x, bool cond) => (float[4, 257, 64, 2] output)
214+ <ir_version: 8, opset_import: [ "" : 18]>
215+ agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
259216 {
260- output = pkg.custom.conditional_dft (input_x, cond)
217+ shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
218+ reshape_x = Reshape (input_x, shape_a)
219+ shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
220+ reshape_y = Reshape (input_x, shape_b)
221+ gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
222+ output = foo(gridsample)
261223 }
262224
263- <domain: "pkg.custom", opset_import: [ "" : 18]>
264- conditional_dft (x, cond) => (result) {
265- result = If (cond) <then_branch: graph = then_graph () => (out) {
266- shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512, 1}>()
267- reshape_x = Reshape (x, shape_a)
268- dft = DFT <axis = 2, onesided = 1> (reshape_x)
269- shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
270- out = Reshape (dft, shape_c)
271- }, else_branch: graph = else_graph () => (out) {
272- shape_c = Constant<value: tensor = int64[4] {4, 257, 64, 2}>()
273- out = Reshape (x, shape_c)
274- }>
225+ <opset_import: [ "" : 18]>
226+ foo (x) => (dft) {
227+ dft = DFT <axis = 2, onesided = 1> (x)
275228 }
276229 """
277230 )
278- # Verify the function exists with correct initial state
279- self .assertEqual (len (model .functions ), 1 )
280- func = model .functions [("pkg.custom" , "conditional_dft" , "" )]
281- self .assertEqual (len (func ), 1 ) # 1 node (If) in the function
282-
283- # Verify the If node has subgraphs
284- if_node = func [0 ]
285- self .assertEqual (if_node .op_type , "If" )
286- then_branch = if_node .attributes ["then_branch" ].as_graph ()
287- else_branch = if_node .attributes ["else_branch" ].as_graph ()
288- self .assertEqual (len (then_branch ), 5 ) # 5 nodes in then_branch
289- self .assertEqual (len (else_branch ), 2 ) # 2 nodes in else_branch
290-
291231 target_version = 20
292- # Use internal API to test function version conversion without inlining
293232 version_converter .convert_version (model , target_version = target_version )
294233 self .assertEqual (model .opset_imports ["" ], target_version )
295234
296- # Verify nodes inside the function's If node subgraphs were version-converted
297- func = model .functions [("pkg.custom" , "conditional_dft" , "" )]
298- if_node = func [0 ]
299- self .assertEqual (if_node .op_type , "If" )
300- self .assertEqual (if_node .version , 20 )
301-
302- # Check then_branch subgraph nodes
303- then_branch = if_node .attributes ["then_branch" ].as_graph ()
304- # After DFT adapter, a new Constant node is inserted for dft_length
305- self .assertEqual (len (then_branch ), 6 ) # 5 + 1 new Constant for DFT
306- dft_node = None
307- for node in then_branch :
308- self .assertEqual (node .version , 20 )
309- if node .op_type == "DFT" :
310- dft_node = node
311- self .assertIsNotNone (dft_node )
312- self .assertEqual (len (dft_node .inputs ), 3 ) # DFT 19->20 adds dft_length input
313-
314- # Check else_branch subgraph nodes
315- else_branch = if_node .attributes ["else_branch" ].as_graph ()
316- self .assertEqual (len (else_branch ), 2 )
317- for node in else_branch :
318- self .assertEqual (node .version , 20 )
235+ self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
236+ self .assertEqual (model .graph .node (0 ).version , 20 )
237+ self .assertEqual (model .graph .node (1 ).op_type , "Reshape" )
238+ self .assertEqual (model .graph .node (1 ).version , 20 )
239+ self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
240+ self .assertEqual (model .graph .node (4 ).version , 20 )
241+ self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "linear" )
242+ self .assertEqual (model .graph .node (6 ).op_type , "DFT" )
243+ self .assertEqual (model .graph .node (6 ).version , 20 )
244+ self .assertEqual (len (model .graph .node (6 ).inputs ), 3 )
319245
320246
321247class VersionConverter20to21Test (unittest .TestCase ):
0 commit comments