Skip to content

Commit 6e41cfe

Browse files
committed
fix: fixed bugs and issues with test cases
1 parent 7192035 commit 6e41cfe

2 files changed

Lines changed: 6 additions & 16 deletions

File tree

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,9 +2211,9 @@ def __init__(self):
22112211
supports_out=False,
22122212
),
22132213
opinfo_core.OpInfo(
2214-
"test_embedding_bag_with_padding_idx_none",
2214+
"ops.aten.embedding_bag.padding_idx_none",
22152215
op=torch.nn.functional.embedding_bag,
2216-
dtypes=(torch.float32,),
2216+
dtypes=common_dtype.floating_types_and_half(),
22172217
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
22182218
opinfo_core.SampleInput(
22192219
torch.tensor(
@@ -2230,9 +2230,9 @@ def __init__(self):
22302230
],
22312231
),
22322232
opinfo_core.OpInfo(
2233-
"test_embedding_bag_with_padding_idx_int",
2233+
"ops.aten.embedding_bag.padding_idx_int",
22342234
op=torch.nn.functional.embedding_bag,
2235-
dtypes=(torch.float32,),
2235+
dtypes=common_dtype.floating_types_and_half(),
22362236
sample_inputs_func=lambda op_info, device, dtype, requires_grad: [
22372237
opinfo_core.SampleInput(
22382238
torch.tensor(

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,12 +1059,12 @@ def _where_input_wrangler(
10591059
reason="fixme: results mismatch in torch nightly.",
10601060
),
10611061
TorchLibOpInfo(
1062-
"test_embedding_bag_with_padding_idx_none",
1062+
"ops.aten.embedding_bag.padding_idx_none",
10631063
core_ops.aten_embedding_bag,
10641064
input_wrangler=_embedding_bag_input_wrangler,
10651065
),
10661066
TorchLibOpInfo(
1067-
"test_embedding_bag_with_padding_idx_int",
1067+
"ops.aten.embedding_bag.padding_idx_int",
10681068
core_ops.aten_embedding_bag,
10691069
input_wrangler=_embedding_bag_input_wrangler,
10701070
),
@@ -1076,16 +1076,6 @@ def _where_input_wrangler(
10761076
compare_shape_only_for_output=(1, 2, 3),
10771077
input_wrangler=_embedding_bag_input_wrangler,
10781078
),
1079-
TorchLibOpInfo(
1080-
"test_embedding_bag_with_padding_idx_none",
1081-
core_ops.aten_embedding_bag,
1082-
input_wrangler=_embedding_bag_input_wrangler,
1083-
),
1084-
TorchLibOpInfo(
1085-
"test_embedding_bag_with_padding_idx_int",
1086-
core_ops.aten_embedding_bag,
1087-
input_wrangler=_embedding_bag_input_wrangler,
1088-
),
10891079
TorchLibOpInfo(
10901080
"ops.aten.embedding_renorm",
10911081
core_ops.aten_embedding_renorm,

0 commit comments

Comments
 (0)