@@ -2483,44 +2483,6 @@ def __init__(self):
24832483 sample_inputs_func = sample_inputs_embedding_bag_padding_idx ,
24842484 supports_out = False ,
24852485 ),
2486- opinfo_core .OpInfo (
2487- "ops.aten.embedding_bag.padding_idx_none" ,
2488- op = torch .nn .functional .embedding_bag ,
2489- dtypes = common_dtype .floating_types_and_half (),
2490- sample_inputs_func = lambda op_info , device , dtype , requires_grad : [
2491- opinfo_core .SampleInput (
2492- torch .tensor (
2493- [[1.0 , 1.0 , 1.0 ], [2.0 , 2.0 , 2.0 ], [3.0 , 3.0 , 3.0 ], [4.0 , 4.0 , 4.0 ]],
2494- dtype = dtype ,
2495- device = device ,
2496- ),
2497- args = (
2498- torch .tensor ([0 , 1 , 2 , 3 ], dtype = torch .int64 , device = device ),
2499- torch .tensor ([0 , 2 ], dtype = torch .int64 , device = device ),
2500- ),
2501- kwargs = {"padding_idx" : None },
2502- )
2503- ],
2504- ),
2505- opinfo_core .OpInfo (
2506- "ops.aten.embedding_bag.padding_idx_int" ,
2507- op = torch .nn .functional .embedding_bag ,
2508- dtypes = common_dtype .floating_types_and_half (),
2509- sample_inputs_func = lambda op_info , device , dtype , requires_grad : [
2510- opinfo_core .SampleInput (
2511- torch .tensor (
2512- [[1.0 , 1.0 , 1.0 ], [2.0 , 2.0 , 2.0 ], [3.0 , 3.0 , 3.0 ]],
2513- dtype = dtype ,
2514- device = device ,
2515- ),
2516- args = (
2517- torch .tensor ([0 , 1 , 2 ], dtype = torch .int64 , device = device ),
2518- torch .tensor ([0 , 2 ], dtype = torch .int64 , device = device ),
2519- ),
2520- kwargs = {"padding_idx" : 0 },
2521- )
2522- ],
2523- ),
25242486 opinfo_core .OpInfo (
25252487 "ops.aten.embedding_renorm" ,
25262488 aten_name = "embedding_renorm" ,
0 commit comments