@@ -955,6 +955,7 @@ def run_example(
955955 baseline_name : str = "torch" ,
956956 rtol : float = 1e-2 ,
957957 atol : float = 1e-1 ,
958+ max_mismatch_pct : float | None = None ,
958959 bwd : bool = False ,
959960 trace_path : str | None = None ,
960961 process_group_name : str | None = None ,
@@ -969,6 +970,8 @@ def run_example(
969970 baseline_name: Name for single baseline in output (default: "torch")
970971 rtol: Relative tolerance for correctness check (default: 1e-2)
971972 atol: Absolute tolerance for correctness check (default: 1e-1)
973+ max_mismatch_pct: If set, use assert_close_with_mismatch_tolerance with this mismatch
974+ fraction tolerance instead of strict assert_close (default: None)
972975 bwd: Whether to also test backward pass (default: False)
973976 trace_path: if not None, do profiling and save trace to this path
974977 """
@@ -998,12 +1001,21 @@ def run_example(
9981001 # Clone args to avoid buffer donation issues (e.g., Pallas/TPU)
9991002 cloned_args = _clone_args (args , process_group_name = process_group_name )
10001003 result = func (* cloned_args ).clone ()
1001- torch .testing .assert_close (
1002- result .to (torch .float32 ),
1003- expected .to (torch .float32 ),
1004- rtol = rtol ,
1005- atol = atol ,
1006- )
1004+ if max_mismatch_pct is not None :
1005+ assert_close_with_mismatch_tolerance (
1006+ result .to (torch .float32 ),
1007+ expected .to (torch .float32 ),
1008+ atol = atol ,
1009+ rtol = rtol ,
1010+ max_mismatch_pct = max_mismatch_pct ,
1011+ )
1012+ else :
1013+ torch .testing .assert_close (
1014+ result .to (torch .float32 ),
1015+ expected .to (torch .float32 ),
1016+ rtol = rtol ,
1017+ atol = atol ,
1018+ )
10071019
10081020 # Test backward pass
10091021 if bwd :
0 commit comments