diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml new file mode 100644 index 000000000..4f39c463b --- /dev/null +++ b/.github/workflows/benchmark_tpu.yml @@ -0,0 +1,140 @@ +name: Benchmark TPU + +on: + workflow_call: + inputs: + kernels: + required: true + type: string + +permissions: + contents: read + +jobs: + benchmark: + name: benchmark-tpu-pallas + + env: + HELION_BACKEND: pallas + HELION_AUTOTUNE_LOG_LEVEL: INFO + HELION_AUTOTUNE_EFFORT: quick + + runs-on: linux.google.tpuv7x.1 + + defaults: + run: + shell: bash -l {0} + + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment + run: | + uv venv --python 3.12 + + - name: Install PyTorch (CPU nightly) + run: | + source .venv/bin/activate + uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + + - name: Install Helion + run: | + source .venv/bin/activate + uv pip install setuptools ninja + SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]' + python -c "import helion; print(helion.__name__)" + + - name: Install TPU dependencies (Pallas) + run: | + set -euxo pipefail + source .venv/bin/activate + uv pip install \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --pre \ + 'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' + # Install Bazel + if ! command -v bazel &> /dev/null; then + sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel + sudo chmod +x /usr/local/bin/bazel + fi + # Install gcloud CLI if not present (needed for Secret Manager) + if ! 
command -v gcloud &> /dev/null; then + sudo apt-get install -y apt-transport-https ca-certificates gpg curl + curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg + echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list + sudo apt-get update && sudo apt-get install -y google-cloud-cli + fi + # Clone torch_tpu via GCP Secret Manager SSH key (same as pytorch CI) + TORCH_TPU_COMMIT=$(cat .github/ci_commit_pins/torch_tpu.txt) + set +x + gcloud secrets versions access latest \ + --secret="torchtpu-read-key" \ + --project="ml-velocity-actions-testing" > /tmp/torch_tpu_ssh_key + set -x + chmod 600 /tmp/torch_tpu_ssh_key + GIT_SSH_COMMAND="ssh -i /tmp/torch_tpu_ssh_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" \ + git clone git@github.com:google-pytorch/torch_tpu.git /tmp/torch_tpu + rm -f /tmp/torch_tpu_ssh_key + cd /tmp/torch_tpu + git checkout "${TORCH_TPU_COMMIT}" + # Build torch_tpu wheel + export TORCH_SOURCE=$(python -c "import torch; import os; print(os.path.dirname(os.path.dirname(torch.__file__)))") + export SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") + bazel build -c opt //ci/wheel:torch_tpu_wheel --config=helion_public_caching_readwrite --define WHEEL_VERSION=0.1.0 --define TORCH_SOURCE=local --action_env=PYTHONPATH=$TORCH_SOURCE:$SITE_PACKAGES --action_env=JAX_PLATFORMS=cpu + uv pip install bazel-bin/ci/wheel/*.whl + cd - + rm -rf /tmp/torch_tpu + # Verify + python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')" + + - name: Run TPU Benchmark + run: | + source .venv/bin/activate + + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + + KERNELS="${{ inputs.kernels }}" + echo "==========================================" + echo "TPU Benchmark: autotuning pass" + echo "Kernels: $KERNELS" + echo 
"==========================================" + + # First pass: autotune (populates cache) + python benchmarks/run_tpu.py --kernel "$KERNELS" --num-shapes 1 + + # Let TPU cool down + sleep 1m + + echo "==========================================" + echo "TPU Benchmark: cache-hit verification pass" + echo "==========================================" + + # Second pass: verify cache hits and record results + HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 \ + python benchmarks/run_tpu.py \ + --kernel "$KERNELS" \ + --num-shapes 1 \ + --output "$TEST_REPORTS_DIR/helionbench.json" + + if [[ -s "$TEST_REPORTS_DIR/helionbench.json" ]]; then + cat "$TEST_REPORTS_DIR/helionbench.json" + else + echo "helionbench.json is missing or empty (some kernels may have failed)" + fi + + - name: Upload the benchmark results to GitHub + uses: actions/upload-artifact@v7 + with: + name: benchmark-results-tpu + path: test/test-reports diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml new file mode 100644 index 000000000..0914deb44 --- /dev/null +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -0,0 +1,28 @@ +name: Benchmark TPU Nightly + +on: + push: # TODO: remove before merging — temporary trigger for CI testing + branches: + - yifeixu/tpu-nightly-benchmark + schedule: + - cron: '0 10 * * *' # Runs at 2 AM PST (10 AM UTC) + workflow_dispatch: + inputs: + kernels: + description: 'Comma-separated list of kernels to benchmark' + required: false + type: string + # Excluded kernels: + # layer_norm: OOB slice when reduction_loops doesn't evenly divide the reduction dim (gh#1937) + default: "exp,add,softmax_two_pass,welford,attention,bmm,geglu,grpo_loss,jagged_hstu_attn,low_mem_dropout,swiglu" + +permissions: + contents: read + +jobs: + benchmark-tpu: + uses: ./.github/workflows/benchmark_tpu.yml + permissions: + contents: read + with: + kernels: ${{ github.event.inputs.kernels || 
'exp,add,softmax_two_pass,welford,attention,bmm,geglu,grpo_loss,jagged_hstu_attn,low_mem_dropout,swiglu' }} diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py new file mode 100644 index 000000000..000236559 --- /dev/null +++ b/benchmarks/run_tpu.py @@ -0,0 +1,435 @@ +"""TPU/Pallas benchmark runner for Helion examples. + +Runs selected Helion examples with autotuning on TPU and reports results +in the same JSON format as the GPU benchmark runner (benchmarks/run.py). + +Usage: + # Run all default kernels + HELION_BACKEND=pallas python benchmarks/run_tpu.py + + # Run specific kernels + HELION_BACKEND=pallas python benchmarks/run_tpu.py --kernel exp,add + + # Output results to JSON (compatible with pytorch benchmark hub) + HELION_BACKEND=pallas python benchmarks/run_tpu.py --output results.json + + # List available kernels + HELION_BACKEND=pallas python benchmarks/run_tpu.py --list-kernels +""" + +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from dataclasses import field +import functools +import importlib.util +import json +import os +from pathlib import Path +import signal +import sys +import time +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +import torch + +from helion._testing import DEVICE +from helion._testing import run_example + +if TYPE_CHECKING: + import types + +EXAMPLES_DIR = Path(__file__).parent.parent / "examples" + + +# Shape generators for multi-shape benchmarking. +# Each returns a list of (label, args_tuple) pairs. 
def _exp_shapes() -> list[tuple[str, tuple[Any, ...]]]:
    """Build (label, args) benchmark cases for the 1-D ``exp`` kernel.

    Each case is a single float32 tensor of growing length; the label
    encodes the shape so it can be printed next to the timing results.
    """
    cases: list[tuple[str, tuple[Any, ...]]] = []
    for n in (1024, 4096, 16384, 65536, 262144, 1048576):
        t = torch.randn(n, device=DEVICE, dtype=torch.float32, requires_grad=True)
        cases.append((f"[{n}]", (t,)))
    return cases


def _add_shapes() -> list[tuple[str, tuple[Any, ...]]]:
    """Build (label, args) benchmark cases for elementwise ``add``.

    Each case is a pair of square bfloat16 tensors of matching shape.
    """
    cases: list[tuple[str, tuple[Any, ...]]] = []
    for m, n in ((128, 128), (256, 256), (512, 512), (1024, 1024), (2048, 2048)):
        lhs = torch.randn(m, n, device=DEVICE, dtype=torch.bfloat16)
        rhs = torch.randn(m, n, device=DEVICE, dtype=torch.bfloat16)
        cases.append((f"[{m},{n}]", (lhs, rhs)))
    return cases


def _softmax_shapes() -> list[tuple[str, tuple[Any, ...]]]:
    """Build (label, args) benchmark cases for two-pass softmax.

    The row count is held at 1024 while the reduction dimension grows.
    """
    cases: list[tuple[str, tuple[Any, ...]]] = []
    for m, n in ((1024, 256), (1024, 512), (1024, 1024), (1024, 2048), (1024, 4096)):
        x = torch.randn(m, n, device=DEVICE, dtype=torch.bfloat16)
        cases.append((f"[{m},{n}]", (x,)))
    return cases


# Kernel mappings for TPU/Pallas benchmarks.
# Format: kernel_name -> (module_file, kernel_fn_name, baseline_fn, shapes_fn)
#   module_file:    filename in examples/ (without .py)
#   kernel_fn_name: attribute name of the helion kernel in the module
#   baseline_fn:    callable that produces reference output (None = call main())
#   shapes_fn:      callable returning list of (label, args) pairs (None = call main())
#
# This list contains only kernels that reliably pass on Pallas/TPU.
+KernelMapping = tuple[ + str, + str, + Callable[..., Any] | None, + Callable[[], list[tuple[str, tuple[Any, ...]]]] | None, +] +KERNEL_MAPPINGS: dict[str, KernelMapping] = { + "exp": ("exp", "exp", torch.exp, _exp_shapes), + "add": ("add", "add", torch.add, _add_shapes), + "softmax_two_pass": ( + "softmax", + "softmax_two_pass", + functools.partial(torch.softmax, dim=-1), + _softmax_shapes, + ), + "welford": ("welford", "welford", None, None), + "attention": ("attention", "attention", None, None), + "bmm": ("bmm", "bmm", None, None), + "geglu": ("geglu", "geglu", None, None), + "grpo_loss": ("grpo_loss", "grpo_loss_forward", None, None), + "jagged_hstu_attn": ("jagged_hstu_attn", "_helion_jagged_attention_kernel", None, None), + "low_mem_dropout": ("low_mem_dropout", "low_mem_dropout", None, None), + "swiglu": ("swiglu", "swiglu_fwd", None, None), +} + + +@dataclass +class ShapeResult: + shape: str + passed: bool + kernel_time_ms: float = 0.0 + baseline_time_ms: float = 0.0 + speedup: float = 0.0 + error: str | None = None + + +@dataclass +class KernelResult: + name: str + passed: bool + kernel_time_ms: float = 0.0 + error: str | None = None + shape_results: list[ShapeResult] = field(default_factory=list) + + +def import_example(module_file: str) -> types.ModuleType: + """Import an example module by filename.""" + module_path = EXAMPLES_DIR / f"{module_file}.py" + spec = importlib.util.spec_from_file_location( + f"examples.{module_file}", module_path + ) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +KERNEL_TIMEOUT = int(os.environ.get("HELION_BENCHMARK_KERNEL_TIMEOUT", "1200")) +NUM_SHAPES: int | None = None # Set from CLI; None means all shapes + + +class _KernelTimeout(Exception): + """Raised by SIGALRM when a kernel exceeds its timeout.""" + + +def _alarm_handler(signum: int, frame: object) -> None: + raise _KernelTimeout + + +def run_kernel(name: str) -> 
KernelResult: + """Run a single kernel benchmark with a signal-based timeout. + + Uses SIGALRM instead of multiprocessing to avoid fork-after-TPU-init + deadlocks on Linux. + """ + old_handler = signal.signal(signal.SIGALRM, _alarm_handler) + signal.alarm(KERNEL_TIMEOUT) + try: + return run_kernel_inner(name) + except _KernelTimeout: + return KernelResult( + name=name, + passed=False, + error=f"Timed out after {KERNEL_TIMEOUT}s", + ) + except Exception as e: + return KernelResult(name=name, passed=False, error=str(e)) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +def run_kernel_inner(name: str) -> KernelResult: + """Run a single kernel benchmark: accuracy check + timing vs baseline.""" + if name not in KERNEL_MAPPINGS: + return KernelResult(name=name, passed=False, error=f"Unknown kernel: {name}") + + module_file, kernel_fn_name, baseline_fn, shapes_fn = KERNEL_MAPPINGS[name] + + try: + mod = import_example(module_file) + kernel_fn = getattr(mod, kernel_fn_name) + + # For kernels with None baseline/shapes, call main() directly + # (they have complex setup that's hard to replicate here) + if baseline_fn is None or shapes_fn is None: + start = time.perf_counter() + mod.main() + elapsed = time.perf_counter() - start + return KernelResult( + name=name, + passed=True, + kernel_time_ms=elapsed * 1000, + ) + + shapes = shapes_fn() + if NUM_SHAPES is not None: + shapes = shapes[:NUM_SHAPES] + all_passed = True + shape_results: list[ShapeResult] = [] + + for label, args in shapes: + print(f" Shape {label}:", file=sys.stderr) + try: + timings = run_example(kernel_fn, baseline_fn, args) + kernel_ms = timings.get("helion", 0.0) + baseline_ms = timings.get("torch", 0.0) + speedup = baseline_ms / kernel_ms if kernel_ms > 0 else 0.0 + shape_results.append( + ShapeResult( + shape=label, + passed=True, + kernel_time_ms=kernel_ms, + baseline_time_ms=baseline_ms, + speedup=speedup, + ) + ) + except Exception as e: + print(f" FAIL: {e}", 
file=sys.stderr) + shape_results.append( + ShapeResult(shape=label, passed=False, error=str(e)) + ) + all_passed = False + + return KernelResult(name=name, passed=all_passed, shape_results=shape_results) + + except Exception as e: + return KernelResult(name=name, passed=False, error=str(e)) + + +def write_results_json(output: str, results: list[KernelResult]) -> None: + """Write results in the same JSON format as benchmarks/run.py for pytorch benchmark hub.""" + device = os.environ.get("HELION_BACKEND", "pallas") + records: list[dict[str, Any]] = [] + for result in results: + if result.shape_results: + for sr in result.shape_results: + records.append( + { + "benchmark": { + "name": "Helion TPU Benchmark", + "extra_info": {"device": device}, + }, + "model": {"name": result.name}, + "metric": { + "name": "accuracy", + "benchmark_values": [1.0 if sr.passed else 0.0], + }, + "shape": [sr.shape], + } + ) + else: + records.append( + { + "benchmark": { + "name": "Helion TPU Benchmark", + "extra_info": {"device": device}, + }, + "model": {"name": result.name}, + "metric": { + "name": "accuracy", + "benchmark_values": [1.0 if result.passed else 0.0], + }, + "shape": [], + } + ) + if result.kernel_time_ms > 0: + records.append( + { + "benchmark": { + "name": "Helion TPU Benchmark", + "extra_info": {"device": device}, + }, + "model": {"name": result.name}, + "metric": { + "name": "kernel_time_ms", + "benchmark_values": [result.kernel_time_ms], + }, + "shape": [], + } + ) + + if os.path.exists(output): + try: + with open(output) as f: + existing = json.load(f) + if isinstance(existing, list): + records = existing + records + except (OSError, json.JSONDecodeError): + pass + + with open(output, "w") as f: + json.dump(records, f, indent=2) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="TPU/Pallas benchmark runner for Helion examples", + allow_abbrev=False, + ) + parser.add_argument( + "--kernel", + "--op", + type=str, + dest="kernel", + 
help="Comma-separated list of kernels to run. If not specified, runs all kernels.", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output JSON file path (compatible with pytorch benchmark hub)", + ) + parser.add_argument( + "--num-shapes", + type=int, + default=None, + help="Max number of shapes to benchmark per kernel (default: all)", + ) + parser.add_argument( + "--list-kernels", + action="store_true", + help="List available kernel names and exit", + ) + args = parser.parse_args() + + global NUM_SHAPES + NUM_SHAPES = args.num_shapes + + if args.list_kernels: + for name in KERNEL_MAPPINGS: + print(name) + return + + if args.kernel: + kernel_names = [k.strip() for k in args.kernel.split(",") if k.strip()] + # Validate + for name in kernel_names: + if name not in KERNEL_MAPPINGS: + print(f"Error: Unknown kernel '{name}'", file=sys.stderr) + print( + f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", + file=sys.stderr, + ) + sys.exit(1) + else: + kernel_names = list(KERNEL_MAPPINGS.keys()) + + print( + f"Running {len(kernel_names)} TPU kernels: {', '.join(kernel_names)}", + file=sys.stderr, + ) + print( + f"HELION_BACKEND={os.environ.get('HELION_BACKEND', '(not set)')}", + file=sys.stderr, + ) + print("=" * 65, file=sys.stderr) + + results: list[KernelResult] = [] + for name in kernel_names: + print(f"\n{'=' * 65}", file=sys.stderr) + print(f"Kernel: {name}", file=sys.stderr) + print(f"{'=' * 65}", file=sys.stderr) + result = run_kernel(name) + results.append(result) + + status = "PASS" if result.passed else "FAIL" + print(f" Status: {status}", file=sys.stderr) + if result.error: + print(f" Error: {result.error}", file=sys.stderr) + if result.shape_results: + for sr in result.shape_results: + sr_status = "PASS" if sr.passed else "FAIL" + print(f" {sr.shape}: {sr_status}", file=sys.stderr) + + # Summary table + print(f"\n{'=' * 75}", file=sys.stderr) + print("Summary", file=sys.stderr) + print(f"{'=' * 75}", file=sys.stderr) + 
print( + f"{'Kernel':<22} {'Shape':<16} {'Status':<8} {'Helion (ms)':<14} {'Torch (ms)':<14} {'Speedup':<10}", + file=sys.stderr, + ) + print(f"{'-' * 75}", file=sys.stderr) + for result in results: + if result.shape_results: + for sr in result.shape_results: + status = "PASS" if sr.passed else "FAIL" + kernel_str = ( + f"{sr.kernel_time_ms:.4f}" if sr.kernel_time_ms > 0 else "-" + ) + baseline_str = ( + f"{sr.baseline_time_ms:.4f}" if sr.baseline_time_ms > 0 else "-" + ) + speedup_str = f"{sr.speedup:.2f}x" if sr.speedup > 0 else "-" + print( + f"{result.name:<22} {sr.shape:<16} {status:<8} {kernel_str:<14} {baseline_str:<14} {speedup_str:<10}", + file=sys.stderr, + ) + else: + status = "PASS" if result.passed else "FAIL" + time_str = ( + f"{result.kernel_time_ms:.1f}" if result.kernel_time_ms > 0 else "-" + ) + print( + f"{result.name:<22} {'main()':<16} {status:<8} {time_str:<14} {'-':<14} {'-':<10}", + file=sys.stderr, + ) + + passed = sum(1 for r in results if r.passed) + total = len(results) + print(f"{'-' * 75}", file=sys.stderr) + print(f"Total: {passed}/{total} passed", file=sys.stderr) + print(f"{'=' * 75}\n", file=sys.stderr) + + if args.output: + write_results_json(args.output, results) + print(f"Results written to {args.output}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/helion/_testing.py b/helion/_testing.py index 3183074e4..8094c15ae 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -1005,9 +1005,12 @@ def run_example( bwd: bool = False, trace_path: str | None = None, process_group_name: str | None = None, -) -> None: +) -> dict[str, float]: """Run complete example: correctness check + benchmark. + Returns: + Dictionary mapping implementation names to their benchmark times in ms. 
+ Args: kernel_fn: Single kernel function, or dict of {name: function} for multiple kernel variants baseline_fn: Single baseline function or dict of {name: function} for multiple baselines @@ -1187,6 +1190,8 @@ def run_example( print(f"{'=' * 65}\n", file=sys.stderr) + return all_times + def _assert_example_result_close( result: object,