@@ -945,6 +945,7 @@ def run_example(
945945 baseline_name : str = "torch" ,
946946 rtol : float = 1e-2 ,
947947 atol : float = 1e-1 ,
948+ max_mismatch_pct : float | None = None ,
948949 bwd : bool = False ,
949950 trace_path : str | None = None ,
950951 process_group_name : str | None = None ,
@@ -959,6 +960,8 @@ def run_example(
959960 baseline_name: Name for single baseline in output (default: "torch")
960961 rtol: Relative tolerance for correctness check (default: 1e-2)
961962 atol: Absolute tolerance for correctness check (default: 1e-1)
963+ max_mismatch_pct: If set, use assert_close_with_mismatch_tolerance with this mismatch
964+ fraction tolerance instead of strict assert_close (default: None)
962965 bwd: Whether to also test backward pass (default: False)
963966 trace_path: if not None, do profiling and save trace to this path
964967 """
@@ -988,12 +991,21 @@ def run_example(
988991 # Clone args to avoid buffer donation issues (e.g., Pallas/TPU)
989992 cloned_args = _clone_args (args , process_group_name = process_group_name )
990993 result = func (* cloned_args ).clone ()
991- torch .testing .assert_close (
992- result .to (torch .float32 ),
993- expected .to (torch .float32 ),
994- rtol = rtol ,
995- atol = atol ,
996- )
994+ if max_mismatch_pct is not None :
995+ assert_close_with_mismatch_tolerance (
996+ result .to (torch .float32 ),
997+ expected .to (torch .float32 ),
998+ atol = atol ,
999+ rtol = rtol ,
1000+ max_mismatch_pct = max_mismatch_pct ,
1001+ )
1002+ else :
1003+ torch .testing .assert_close (
1004+ result .to (torch .float32 ),
1005+ expected .to (torch .float32 ),
1006+ rtol = rtol ,
1007+ atol = atol ,
1008+ )
9971009
9981010 # Test backward pass
9991011 if bwd :
0 commit comments