From 5d5a4d9619a02a01b501e6b5ab3643453f7ab646 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 1 Apr 2026 13:38:09 -0700 Subject: [PATCH 01/14] Add TPU nightly benchmark workflow and runner Add a nightly CI workflow that runs Helion examples with autotuning on TPU, with results published to pytorch benchmark hub. --- .github/workflows/benchmark_tpu.yml | 173 +++++++++ .github/workflows/benchmark_tpu_nightly.yml | 25 ++ benchmarks/run_tpu.py | 385 ++++++++++++++++++++ helion/_testing.py | 7 +- 4 files changed, 589 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/benchmark_tpu.yml create mode 100644 .github/workflows/benchmark_tpu_nightly.yml create mode 100644 benchmarks/run_tpu.py diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml new file mode 100644 index 000000000..de0e6b495 --- /dev/null +++ b/.github/workflows/benchmark_tpu.yml @@ -0,0 +1,173 @@ +name: Benchmark TPU + +on: + workflow_call: + inputs: + kernels: + required: true + type: string + +jobs: + benchmark: + name: benchmark-tpu-pallas + + env: + HELION_BACKEND: pallas + HELION_AUTOTUNE_LOG_LEVEL: INFO + + runs-on: linux.google.tpuv7x.1 + + defaults: + run: + shell: bash -l {0} + + outputs: + benchmark-metadata: ${{ steps.gather-benchmark-metadata.outputs.benchmark-metadata }} + runners-info: ${{ steps.gather-runners-info.outputs.runners-info }} + dependencies: ${{ steps.gather-dependencies.outputs.dependencies }} + + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment + run: | + uv venv --python 3.12 + + - name: Install PyTorch (CPU nightly) + run: | + source .venv/bin/activate + uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + + - name: Install Helion + run: | + source .venv/bin/activate + uv pip install setuptools ninja + SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]' + python -c "import helion; print(helion.__name__)" + + - name: Install TPU dependencies (Pallas) + run: | + set -euxo pipefail + source .venv/bin/activate + uv pip install \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --pre \ + 'jax==0.9.1' 'jaxlib==0.9.1' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' + # Install Bazel + if ! command -v bazel &> /dev/null; then + sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel + sudo chmod +x /usr/local/bin/bazel + fi + # Install gcloud CLI if not present (needed for Secret Manager) + if ! 
command -v gcloud &> /dev/null; then + sudo apt-get install -y apt-transport-https ca-certificates gnupg curl + curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg + echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list + sudo apt-get update && sudo apt-get install -y google-cloud-cli + fi + # Clone torch_tpu via GCP Secret Manager SSH key + TORCH_TPU_COMMIT=$(cat .github/ci_commit_pins/torch_tpu.txt) + set +x + gcloud secrets versions access latest \ + --secret="torchtpu-readonly-key" \ + --project="ml-velocity-actions-testing" > /tmp/torch_tpu_ssh_key + set -x + chmod 600 /tmp/torch_tpu_ssh_key + GIT_SSH_COMMAND="ssh -i /tmp/torch_tpu_ssh_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" \ + git clone git@github.com:google-ml-infra/torch_tpu.git /tmp/torch_tpu + rm -f /tmp/torch_tpu_ssh_key + cd /tmp/torch_tpu + git checkout "${TORCH_TPU_COMMIT}" + # Build torch_tpu wheel + export TORCH_SOURCE=$(python -c "import torch; import os; print(os.path.dirname(os.path.dirname(torch.__file__)))") + export SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") + bazel build -c opt //ci/wheel:torch_tpu_wheel --config=helion_public_caching_readwrite --define WHEEL_VERSION=0.1.0 --define TORCH_SOURCE=local --action_env=PYTHONPATH=$TORCH_SOURCE:$SITE_PACKAGES --action_env=JAX_PLATFORMS=cpu + uv pip install bazel-bin/ci/wheel/*.whl + cd - + rm -rf /tmp/torch_tpu + # Verify + python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')" + + - name: Run TPU Benchmark + run: | + source .venv/bin/activate + + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + + KERNELS="${{ inputs.kernels }}" + echo "==========================================" + echo "TPU Benchmark: autotuning pass" + echo "Kernels: $KERNELS" + echo "==========================================" + + # First pass: autotune (populates cache) + python benchmarks/run_tpu.py --kernel "$KERNELS" + + # Let TPU cool down + sleep 1m + + echo "==========================================" + echo "TPU Benchmark: cache-hit verification pass" + echo "==========================================" + + # Second pass: verify cache hits and record results + HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 \ + python benchmarks/run_tpu.py \ + --kernel "$KERNELS" \ + --output "$TEST_REPORTS_DIR/helionbench.json" + + if [[ ! 
-s "$TEST_REPORTS_DIR/helionbench.json" ]]; then + echo "helionbench.json is missing or empty" + exit 1 + fi + cat "$TEST_REPORTS_DIR/helionbench.json" + + - name: Gather benchmark metadata + id: gather-benchmark-metadata + uses: pytorch/test-infra/.github/actions/gather-benchmark-metadata@main + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + venv: .venv/bin/activate + + - name: Gather runners info + id: gather-runners-info + uses: pytorch/test-infra/.github/actions/gather-runners-info@main + with: + venv: .venv/bin/activate + + - name: Gather dependencies + id: gather-dependencies + uses: pytorch/test-infra/.github/actions/gather-dependencies@main + with: + venv: .venv/bin/activate + + - name: Upload the benchmark results to GitHub + uses: actions/upload-artifact@v7 + with: + name: benchmark-results-tpu + path: test/test-reports + + upload-benchmark-results: + needs: benchmark + uses: pytorch/test-infra/.github/workflows/upload_benchmark_results.yml@main + permissions: + id-token: write + contents: read + with: + benchmark-artifact: benchmark-results-tpu + benchmark-metadata: ${{ needs.benchmark.outputs.benchmark-metadata }} + runners-info: ${{ needs.benchmark.outputs.runners-info }} + dependencies: ${{ needs.benchmark.outputs.dependencies }} + schema-version: v3 + dry-run: false diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml new file mode 100644 index 000000000..07bf6ab50 --- /dev/null +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -0,0 +1,25 @@ +name: Benchmark TPU Nightly + +on: + schedule: + - cron: '0 10 * * *' # Runs at 2 AM PST (10 AM UTC) + workflow_dispatch: + inputs: + kernels: + description: 'Comma-separated list of kernels to benchmark' + required: false + type: string + # Excluded kernels: + # rms_norm: InductorLoweringError in torch.mean reduction codegen for fori_loop/emit_pipeline + # geglu/swiglu: autotuning takes >15min per kernel (large shape 8x2048x4096), many configs fail to compile + # low_mem_dropout: ~37% element accuracy mismatch on all configs except block_sizes=[128] + default: "exp,add,softmax_two_pass,welford,layer_norm" + +jobs: + benchmark-tpu: + uses: ./.github/workflows/benchmark_tpu.yml + permissions: + id-token: write + contents: read + with: + kernels: ${{ github.event.inputs.kernels || 'exp,add,softmax_two_pass,welford,layer_norm' }} diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py new file mode 100644 index 000000000..6759106af --- /dev/null +++ b/benchmarks/run_tpu.py @@ -0,0 +1,385 @@ +"""TPU/Pallas benchmark runner for Helion examples. + +Runs selected Helion examples with autotuning on TPU and reports results +in the same JSON format as the GPU benchmark runner (benchmarks/run.py). 
+ +Usage: + # Run all default kernels + HELION_BACKEND=pallas python benchmarks/run_tpu.py + + # Run specific kernels + HELION_BACKEND=pallas python benchmarks/run_tpu.py --kernel exp,add + + # Output results to JSON (compatible with pytorch benchmark hub) + HELION_BACKEND=pallas python benchmarks/run_tpu.py --output results.json + + # List available kernels + HELION_BACKEND=pallas python benchmarks/run_tpu.py --list-kernels +""" + +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from dataclasses import field +import functools +import importlib.util +import json +import os +from pathlib import Path +import sys +import time +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +import torch + +from helion._testing import DEVICE +from helion._testing import run_example + +if TYPE_CHECKING: + import types + +EXAMPLES_DIR = Path(__file__).parent.parent / "examples" + + +# Shape generators for multi-shape benchmarking. +# Each returns a list of (label, args_tuple) pairs. +def _exp_shapes() -> list[tuple[str, tuple[Any, ...]]]: + sizes = [1024, 4096, 16384, 65536, 262144, 1048576] + return [ + ( + f"[{n}]", + (torch.randn(n, device=DEVICE, dtype=torch.float32, requires_grad=True),), + ) + for n in sizes + ] + + +def _add_shapes() -> list[tuple[str, tuple[Any, ...]]]: + sizes = [(128, 128), (256, 256), (512, 512), (1024, 1024), (2048, 2048)] + return [ + ( + f"[{m},{n}]", + ( + torch.randn(m, n, device=DEVICE, dtype=torch.bfloat16), + torch.randn(m, n, device=DEVICE, dtype=torch.bfloat16), + ), + ) + for m, n in sizes + ] + + +def _softmax_shapes() -> list[tuple[str, tuple[Any, ...]]]: + shapes = [(1024, 256), (1024, 512), (1024, 1024), (1024, 2048), (1024, 4096)] + return [ + ( + f"[{m},{n}]", + (torch.randn(m, n, device=DEVICE, dtype=torch.bfloat16),), + ) + for m, n in shapes + ] + + +# Kernel mappings for TPU/Pallas benchmarks. +# Format: kernel_name -> (module_file, kernel_fn_name, baseline_fn, shapes_fn) +# module_file: filename in examples/ (without .py) +# kernel_fn_name: attribute name of the helion kernel in the module +# baseline_fn: callable that produces reference output (None = call main()) +# shapes_fn: callable returning list of (label, args) pairs (None = call main()) +# +# This list contains only kernels that reliably pass on Pallas/TPU. 
+KernelMapping = tuple[ + str, + str, + Callable[..., Any] | None, + Callable[[], list[tuple[str, tuple[Any, ...]]]] | None, +] +KERNEL_MAPPINGS: dict[str, KernelMapping] = { + "exp": ("exp", "exp", torch.exp, _exp_shapes), + "add": ("add", "add", torch.add, _add_shapes), + "softmax_two_pass": ( + "softmax", + "softmax_two_pass", + functools.partial(torch.softmax, dim=-1), + _softmax_shapes, + ), + "welford": ("welford", "welford", None, None), + "layer_norm": ("layer_norm", "layer_norm", None, None), +} + + +@dataclass +class ShapeResult: + shape: str + passed: bool + kernel_time_ms: float = 0.0 + baseline_time_ms: float = 0.0 + speedup: float = 0.0 + error: str | None = None + + +@dataclass +class KernelResult: + name: str + passed: bool + kernel_time_ms: float = 0.0 + error: str | None = None + shape_results: list[ShapeResult] = field(default_factory=list) + + +def import_example(module_file: str) -> types.ModuleType: + """Import an example module by filename.""" + module_path = EXAMPLES_DIR / f"{module_file}.py" + spec = importlib.util.spec_from_file_location( + f"examples.{module_file}", module_path + ) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def run_kernel(name: str) -> KernelResult: + """Run a single kernel benchmark: accuracy check + timing vs baseline.""" + if name not in KERNEL_MAPPINGS: + return KernelResult(name=name, passed=False, error=f"Unknown kernel: {name}") + + module_file, kernel_fn_name, baseline_fn, shapes_fn = KERNEL_MAPPINGS[name] + + try: + mod = import_example(module_file) + kernel_fn = getattr(mod, kernel_fn_name) + + # For kernels with None baseline/shapes, call main() directly + # (they have complex setup that's hard to replicate here) + if baseline_fn is None or shapes_fn is None: + start = time.perf_counter() + mod.main() + elapsed = time.perf_counter() - start + return KernelResult( + name=name, + passed=True, + kernel_time_ms=elapsed * 1000, + ) + + shapes = shapes_fn() + all_passed = True + shape_results: list[ShapeResult] = [] + + for label, args in shapes: + print(f" Shape {label}:", file=sys.stderr) + try: + timings = run_example(kernel_fn, baseline_fn, args) + kernel_ms = timings.get("helion", 0.0) + baseline_ms = timings.get("torch", 0.0) + speedup = baseline_ms / kernel_ms if kernel_ms > 0 else 0.0 + shape_results.append( + ShapeResult( + shape=label, + passed=True, + kernel_time_ms=kernel_ms, + baseline_time_ms=baseline_ms, + speedup=speedup, + ) + ) + except Exception as e: + print(f" FAIL: {e}", file=sys.stderr) + shape_results.append( + ShapeResult(shape=label, passed=False, error=str(e)) + ) + all_passed = False + + return KernelResult(name=name, passed=all_passed, shape_results=shape_results) + + except Exception as e: + return KernelResult(name=name, passed=False, error=str(e)) + + +def write_results_json(output: str, results: list[KernelResult]) -> None: + """Write results in the same JSON format as benchmarks/run.py for pytorch benchmark hub.""" + device = os.environ.get("HELION_BACKEND", "pallas") + records: list[dict[str, Any]] = [] + for result in results: + if result.shape_results: + for sr in result.shape_results: + records.append( + { + "benchmark": { + "name": "Helion TPU Benchmark", + "extra_info": {"device": device}, + }, + "model": {"name": result.name}, + "metric": { + "name": "accuracy", + "benchmark_values": [1.0 if sr.passed else 0.0], + }, + "shape": [sr.shape], + } + ) + else: + records.append( + { + "benchmark": { + 
"name": "Helion TPU Benchmark", + "extra_info": {"device": device}, + }, + "model": {"name": result.name}, + "metric": { + "name": "accuracy", + "benchmark_values": [1.0 if result.passed else 0.0], + }, + "shape": [], + } + ) + if result.kernel_time_ms > 0: + records.append( + { + "benchmark": { + "name": "Helion TPU Benchmark", + "extra_info": {"device": device}, + }, + "model": {"name": result.name}, + "metric": { + "name": "kernel_time_ms", + "benchmark_values": [result.kernel_time_ms], + }, + "shape": [], + } + ) + + if os.path.exists(output): + try: + with open(output) as f: + existing = json.load(f) + if isinstance(existing, list): + records = existing + records + except (OSError, json.JSONDecodeError): + pass + + with open(output, "w") as f: + json.dump(records, f, indent=2) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="TPU/Pallas benchmark runner for Helion examples", + allow_abbrev=False, + ) + parser.add_argument( + "--kernel", + "--op", + type=str, + dest="kernel", + help="Comma-separated list of kernels to run. If not specified, runs all kernels.", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output JSON file path (compatible with pytorch benchmark hub)", + ) + parser.add_argument( + "--list-kernels", + action="store_true", + help="List available kernel names and exit", + ) + args = parser.parse_args() + + if args.list_kernels: + for name in KERNEL_MAPPINGS: + print(name) + return + + if args.kernel: + kernel_names = [k.strip() for k in args.kernel.split(",") if k.strip()] + # Validate + for name in kernel_names: + if name not in KERNEL_MAPPINGS: + print(f"Error: Unknown kernel '{name}'", file=sys.stderr) + print( + f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", + file=sys.stderr, + ) + sys.exit(1) + else: + kernel_names = list(KERNEL_MAPPINGS.keys()) + + print( + f"Running {len(kernel_names)} TPU kernels: {', '.join(kernel_names)}", + file=sys.stderr, + ) + print( + f"HELION_BACKEND={os.environ.get('HELION_BACKEND', '(not set)')}", + file=sys.stderr, + ) + print("=" * 65, file=sys.stderr) + + results: list[KernelResult] = [] + for name in kernel_names: + print(f"\n{'=' * 65}", file=sys.stderr) + print(f"Kernel: {name}", file=sys.stderr) + print(f"{'=' * 65}", file=sys.stderr) + result = run_kernel(name) + results.append(result) + + status = "PASS" if result.passed else "FAIL" + print(f" Status: {status}", file=sys.stderr) + if result.error: + print(f" Error: {result.error}", file=sys.stderr) + if result.shape_results: + for sr in result.shape_results: + sr_status = "PASS" if sr.passed else "FAIL" + print(f" {sr.shape}: {sr_status}", file=sys.stderr) + + # Summary table + print(f"\n{'=' * 75}", file=sys.stderr) + print("Summary", file=sys.stderr) + print(f"{'=' * 75}", file=sys.stderr) + print( + f"{'Kernel':<22} {'Shape':<16} {'Status':<8} {'Helion (ms)':<14} {'Torch (ms)':<14} {'Speedup':<10}", + file=sys.stderr, + ) + print(f"{'-' * 75}", file=sys.stderr) + for result in results: + if result.shape_results: + for sr in result.shape_results: + status = "PASS" if sr.passed else "FAIL" + kernel_str = ( + f"{sr.kernel_time_ms:.4f}" if sr.kernel_time_ms > 0 else "-" + ) + baseline_str = ( + f"{sr.baseline_time_ms:.4f}" if sr.baseline_time_ms > 0 else "-" + ) + speedup_str = f"{sr.speedup:.2f}x" if sr.speedup > 0 else "-" + print( + f"{result.name:<22} {sr.shape:<16} {status:<8} {kernel_str:<14} {baseline_str:<14} {speedup_str:<10}", + file=sys.stderr, + ) + else: + status = "PASS" if result.passed else 
"FAIL" + time_str = ( + f"{result.kernel_time_ms:.1f}" if result.kernel_time_ms > 0 else "-" + ) + print( + f"{result.name:<22} {'main()':<16} {status:<8} {time_str:<14} {'-':<14} {'-':<10}", + file=sys.stderr, + ) + + passed = sum(1 for r in results if r.passed) + total = len(results) + print(f"{'-' * 75}", file=sys.stderr) + print(f"Total: {passed}/{total} passed", file=sys.stderr) + print(f"{'=' * 75}\n", file=sys.stderr) + + if args.output: + write_results_json(args.output, results) + print(f"Results written to {args.output}", file=sys.stderr) + + if passed < total: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/helion/_testing.py b/helion/_testing.py index 3183074e4..8094c15ae 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -1005,9 +1005,12 @@ def run_example( bwd: bool = False, trace_path: str | None = None, process_group_name: str | None = None, -) -> None: +) -> dict[str, float]: """Run complete example: correctness check + benchmark. + Returns: + Dictionary mapping implementation names to their benchmark times in ms. + Args: kernel_fn: Single kernel function, or dict of {name: function} for multiple kernel variants baseline_fn: Single baseline function or dict of {name: function} for multiple baselines @@ -1187,6 +1190,8 @@ def run_example( print(f"{'=' * 65}\n", file=sys.stderr) + return all_times + def _assert_example_result_close( result: object, From 791e7927cd427fea0bc50b05eeaf566eebfb979e Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 1 Apr 2026 13:41:27 -0700 Subject: [PATCH 02/14] Add top-level permissions block to TPU benchmark workflows --- .github/workflows/benchmark_tpu.yml | 3 +++ .github/workflows/benchmark_tpu_nightly.yml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index de0e6b495..00ea71b11 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -7,6 +7,9 @@ on: required: true type: string +permissions: + contents: read + jobs: benchmark: name: benchmark-tpu-pallas diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml index 07bf6ab50..98cf8dee4 100644 --- a/.github/workflows/benchmark_tpu_nightly.yml +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -15,6 +15,9 @@ on: # low_mem_dropout: ~37% element accuracy mismatch on all configs except block_sizes=[128] default: "exp,add,softmax_two_pass,welford,layer_norm" +permissions: + contents: read + jobs: benchmark-tpu: uses: ./.github/workflows/benchmark_tpu.yml From 323c5653d2c23bfe173d198541906766e02eff0a Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 1 Apr 2026 13:47:02 -0700 Subject: [PATCH 03/14] Add temporary push trigger for CI testing (remove before merge) --- .github/workflows/benchmark_tpu_nightly.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml index 98cf8dee4..e2a50e09b 100644 --- a/.github/workflows/benchmark_tpu_nightly.yml +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -1,6 +1,9 @@ name: Benchmark TPU Nightly on: + push: # TODO: remove before merging — temporary trigger for CI testing + branches: + - yifeixu/tpu-nightly-benchmark schedule: - cron: '0 10 * * *' # Runs at 2 AM PST (10 AM UTC) workflow_dispatch: From 1fbe2932f02da14b484399382d000dbf7b8636c6 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 1 Apr 2026 14:02:48 -0700 Subject: [PATCH 04/14] Fix TPU benchmark CI: align with test.yml 
setup - gnupg -> gpg (package available on runner) - Use correct secret name (torchtpu-read-key) and repo (google-pytorch/torch_tpu) - Update jax/jaxlib to 0.9.2 matching test.yml --- .github/workflows/benchmark_tpu.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index 00ea71b11..ee3a5984c 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -65,7 +65,7 @@ jobs: --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ --pre \ - 'jax==0.9.1' 'jaxlib==0.9.1' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' + 'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' # Install Bazel if ! command -v bazel &> /dev/null; then sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel @@ -73,21 +73,21 @@ jobs: fi # Install gcloud CLI if not present (needed for Secret Manager) if ! command -v gcloud &> /dev/null; then - sudo apt-get install -y apt-transport-https ca-certificates gnupg curl + sudo apt-get install -y apt-transport-https ca-certificates gpg curl curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list sudo apt-get update && sudo apt-get install -y google-cloud-cli fi - # Clone torch_tpu via GCP Secret Manager SSH key + # Clone torch_tpu via GCP Secret Manager SSH key (same as pytorch CI) TORCH_TPU_COMMIT=$(cat .github/ci_commit_pins/torch_tpu.txt) set +x gcloud secrets versions access latest \ - --secret="torchtpu-readonly-key" \ + --secret="torchtpu-read-key" \ --project="ml-velocity-actions-testing" > /tmp/torch_tpu_ssh_key set -x chmod 600 /tmp/torch_tpu_ssh_key GIT_SSH_COMMAND="ssh -i /tmp/torch_tpu_ssh_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" \ - git clone git@github.com:google-ml-infra/torch_tpu.git /tmp/torch_tpu + git clone git@github.com:google-pytorch/torch_tpu.git /tmp/torch_tpu rm -f /tmp/torch_tpu_ssh_key cd /tmp/torch_tpu git checkout "${TORCH_TPU_COMMIT}" From 190987bb68540260191590c3c6645d104765f1b3 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Wed, 1 Apr 2026 17:15:00 -0700 Subject: [PATCH 05/14] Add per-kernel timeout and use quick autotuning for TPU benchmark - Add 600s per-kernel timeout using multiprocessing to handle stuck autotuning (native C++ calls can't be interrupted by Python signals) - Set HELION_AUTOTUNE_EFFORT=quick in CI for faster autotuning (30 initial population, 5 generations vs 100/20 for full) - Timeout configurable via HELION_BENCHMARK_KERNEL_TIMEOUT env var --- .github/workflows/benchmark_tpu.yml | 1 + benchmarks/run_tpu.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index ee3a5984c..c221d4cf8 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -17,6 +17,7 @@ jobs: env: HELION_BACKEND: pallas HELION_AUTOTUNE_LOG_LEVEL: INFO + HELION_AUTOTUNE_EFFORT: quick runs-on: linux.google.tpuv7x.1 diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index 6759106af..743459b7a 100644 --- 
a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -25,6 +25,7 @@ import functools import importlib.util import json +import multiprocessing import os from pathlib import Path import sys @@ -141,7 +142,39 @@ def import_example(module_file: str) -> types.ModuleType: return mod +KERNEL_TIMEOUT = int(os.environ.get("HELION_BENCHMARK_KERNEL_TIMEOUT", "600")) + + +def _run_kernel_impl(name: str, result_queue: multiprocessing.Queue) -> None: # type: ignore[type-arg] + """Run a single kernel in a subprocess (target for multiprocessing).""" + try: + result_queue.put(run_kernel_inner(name)) + except Exception as e: + result_queue.put(KernelResult(name=name, passed=False, error=str(e))) + + def run_kernel(name: str) -> KernelResult: + """Run a single kernel benchmark with a timeout.""" + queue: multiprocessing.Queue = multiprocessing.Queue() # type: ignore[type-arg] + proc = multiprocessing.Process(target=_run_kernel_impl, args=(name, queue)) + proc.start() + proc.join(timeout=KERNEL_TIMEOUT) + if proc.is_alive(): + proc.kill() + proc.join() + return KernelResult( + name=name, + passed=False, + error=f"Timed out after {KERNEL_TIMEOUT}s", + ) + if not queue.empty(): + return queue.get() + return KernelResult( + name=name, passed=False, error="Kernel process exited unexpectedly" + ) + + +def run_kernel_inner(name: str) -> KernelResult: """Run a single kernel benchmark: accuracy check + timing vs baseline.""" if name not in KERNEL_MAPPINGS: return KernelResult(name=name, passed=False, error=f"Unknown kernel: {name}") From 56897df206e31f9cc26b2060ca174b42e1859521 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Thu, 2 Apr 2026 09:19:55 -0700 Subject: [PATCH 06/14] Remove upload-benchmark-results job from TPU benchmark workflow The pytorch/test-infra gather-* actions require pip and nvidia-ml-py, which don't work in a uv venv on TPU runners. Remove the upload job and gather-* steps; keep only the artifact upload for now. 
--- .github/workflows/benchmark_tpu.yml | 38 --------------------- .github/workflows/benchmark_tpu_nightly.yml | 1 - 2 files changed, 39 deletions(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index c221d4cf8..7f57dc69f 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -25,11 +25,6 @@ jobs: run: shell: bash -l {0} - outputs: - benchmark-metadata: ${{ steps.gather-benchmark-metadata.outputs.benchmark-metadata }} - runners-info: ${{ steps.gather-runners-info.outputs.runners-info }} - dependencies: ${{ steps.gather-dependencies.outputs.dependencies }} - steps: - name: Check out code uses: actions/checkout@v6 @@ -137,41 +132,8 @@ jobs: fi cat "$TEST_REPORTS_DIR/helionbench.json" - - name: Gather benchmark metadata - id: gather-benchmark-metadata - uses: pytorch/test-infra/.github/actions/gather-benchmark-metadata@main - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - venv: .venv/bin/activate - - - name: Gather runners info - id: gather-runners-info - uses: pytorch/test-infra/.github/actions/gather-runners-info@main - with: - venv: .venv/bin/activate - - - name: Gather dependencies - id: gather-dependencies - uses: pytorch/test-infra/.github/actions/gather-dependencies@main - with: - venv: .venv/bin/activate - - name: Upload the benchmark results to GitHub uses: actions/upload-artifact@v7 with: name: benchmark-results-tpu path: test/test-reports - - upload-benchmark-results: - needs: benchmark - uses: pytorch/test-infra/.github/workflows/upload_benchmark_results.yml@main - permissions: - id-token: write - contents: read - with: - benchmark-artifact: benchmark-results-tpu - benchmark-metadata: ${{ needs.benchmark.outputs.benchmark-metadata }} - runners-info: ${{ needs.benchmark.outputs.runners-info }} - dependencies: ${{ needs.benchmark.outputs.dependencies }} - schema-version: v3 - dry-run: false diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml index e2a50e09b..3f590fb38 100644 --- a/.github/workflows/benchmark_tpu_nightly.yml +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -25,7 +25,6 @@ jobs: benchmark-tpu: uses: ./.github/workflows/benchmark_tpu.yml permissions: - id-token: write contents: read with: kernels: ${{ github.event.inputs.kernels || 'exp,add,softmax_two_pass,welford,layer_norm' }} From 1dc9866ea27d0128e48ffe50a6cde7c25b7c61c7 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Thu, 2 Apr 2026 10:29:45 -0700 Subject: [PATCH 07/14] Fix TPU benchmark timeouts: add --num-shapes, increase timeout, limit generations - Add --num-shapes CLI flag to control how many shapes per kernel (default: all) - Restore full shape lists but use --num-shapes 1 in CI to avoid multiplied autotuning time - Increase per-kernel timeout from 600s to 1200s (quick autotuning on v7 takes ~10min) - Set HELION_AUTOTUNE_MAX_GENERATIONS=2 to further limit autotuning time - Don't fail the job on partial kernel failures (report results for what passed) --- .github/workflows/benchmark_tpu.yml | 12 +++++++----- benchmarks/run_tpu.py | 17 +++++++++++++---- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index 7f57dc69f..04af1cbab 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -18,6 +18,7 @@ jobs: HELION_BACKEND: pallas HELION_AUTOTUNE_LOG_LEVEL: INFO HELION_AUTOTUNE_EFFORT: quick + HELION_AUTOTUNE_MAX_GENERATIONS: 2 runs-on: 
linux.google.tpuv7x.1 @@ -111,7 +112,7 @@ jobs: echo "==========================================" # First pass: autotune (populates cache) - python benchmarks/run_tpu.py --kernel "$KERNELS" + python benchmarks/run_tpu.py --kernel "$KERNELS" --num-shapes 1 # Let TPU cool down sleep 1m @@ -124,13 +125,14 @@ jobs: HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 \ python benchmarks/run_tpu.py \ --kernel "$KERNELS" \ + --num-shapes 1 \ --output "$TEST_REPORTS_DIR/helionbench.json" - if [[ ! -s "$TEST_REPORTS_DIR/helionbench.json" ]]; then - echo "helionbench.json is missing or empty" - exit 1 + if [[ -s "$TEST_REPORTS_DIR/helionbench.json" ]]; then + cat "$TEST_REPORTS_DIR/helionbench.json" + else + echo "helionbench.json is missing or empty (some kernels may have failed)" fi - cat "$TEST_REPORTS_DIR/helionbench.json" - name: Upload the benchmark results to GitHub uses: actions/upload-artifact@v7 diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index 743459b7a..32a15d3a0 100644 --- a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -142,7 +142,8 @@ def import_example(module_file: str) -> types.ModuleType: return mod -KERNEL_TIMEOUT = int(os.environ.get("HELION_BENCHMARK_KERNEL_TIMEOUT", "600")) +KERNEL_TIMEOUT = int(os.environ.get("HELION_BENCHMARK_KERNEL_TIMEOUT", "1200")) +NUM_SHAPES: int | None = None # Set from CLI; None means all shapes def _run_kernel_impl(name: str, result_queue: multiprocessing.Queue) -> None: # type: ignore[type-arg] @@ -198,6 +199,8 @@ def run_kernel_inner(name: str) -> KernelResult: ) shapes = shapes_fn() + if NUM_SHAPES is not None: + shapes = shapes[:NUM_SHAPES] all_passed = True shape_results: list[ShapeResult] = [] @@ -313,6 +316,12 @@ def main() -> None: default=None, help="Output JSON file path (compatible with pytorch benchmark hub)", ) + parser.add_argument( + "--num-shapes", + type=int, + default=None, + help="Max number of shapes to benchmark per kernel (default: all)", + ) parser.add_argument( "--list-kernels", action="store_true", @@ -320,6 +329,9 @@ def main() -> None: ) args = parser.parse_args() + global NUM_SHAPES + NUM_SHAPES = args.num_shapes + if args.list_kernels: for name in KERNEL_MAPPINGS: print(name) @@ -410,9 +422,6 @@ def main() -> None: write_results_json(args.output, results) print(f"Results written to {args.output}", file=sys.stderr) - if passed < total: - sys.exit(1) - if __name__ == "__main__": main() From 65a19ceb579a8bf564292dd40e7383785bf5bb53 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Thu, 2 Apr 2026 14:35:36 -0700 Subject: [PATCH 08/14] Fix TPU benchmark deadlock: replace multiprocessing with signal-based timeout The benchmark runner was using multiprocessing.Process (fork) for per-kernel timeouts. On Linux, forking after TPU/JAX initialization causes deadlocks because JAX's internal threads and locks don't survive fork correctly. This caused every kernel to hang for the full timeout (1200s) on CI. Replace with signal.SIGALRM which runs everything in one process, avoiding the fork-after-init issue entirely. 
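For reference, a minimal standalone sketch of the SIGALRM pattern this change adopts (illustrative names only, not part of the diff below; assumes a Unix main thread, and note that the alarm is only handled once control returns to Python, so a stuck native call still runs until it finishes):

import signal


class KernelTimeout(Exception):
    """Raised when the alarm fires before the wrapped call returns."""


def _on_alarm(signum: int, frame: object) -> None:
    raise KernelTimeout


def call_with_timeout(fn, timeout_s: int):
    # Schedule SIGALRM, run fn() in-process, then always restore prior state.
    old_handler = signal.signal(signal.SIGALRM, _on_alarm)
    signal.alarm(timeout_s)
    try:
        return fn()
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)  # restore previous handler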
--- benchmarks/run_tpu.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index 32a15d3a0..9b8c583ad 100644 --- a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -25,9 +25,9 @@ import functools import importlib.util import json -import multiprocessing import os from pathlib import Path +import signal import sys import time from typing import TYPE_CHECKING @@ -146,33 +146,35 @@ def import_example(module_file: str) -> types.ModuleType: NUM_SHAPES: int | None = None # Set from CLI; None means all shapes -def _run_kernel_impl(name: str, result_queue: multiprocessing.Queue) -> None: # type: ignore[type-arg] - """Run a single kernel in a subprocess (target for multiprocessing).""" - try: - result_queue.put(run_kernel_inner(name)) - except Exception as e: - result_queue.put(KernelResult(name=name, passed=False, error=str(e))) +class _KernelTimeout(Exception): + """Raised by SIGALRM when a kernel exceeds its timeout.""" + + +def _alarm_handler(signum: int, frame: object) -> None: + raise _KernelTimeout def run_kernel(name: str) -> KernelResult: - """Run a single kernel benchmark with a timeout.""" - queue: multiprocessing.Queue = multiprocessing.Queue() # type: ignore[type-arg] - proc = multiprocessing.Process(target=_run_kernel_impl, args=(name, queue)) - proc.start() - proc.join(timeout=KERNEL_TIMEOUT) - if proc.is_alive(): - proc.kill() - proc.join() + """Run a single kernel benchmark with a signal-based timeout. + + Uses SIGALRM instead of multiprocessing to avoid fork-after-TPU-init + deadlocks on Linux. + """ + old_handler = signal.signal(signal.SIGALRM, _alarm_handler) + signal.alarm(KERNEL_TIMEOUT) + try: + return run_kernel_inner(name) + except _KernelTimeout: return KernelResult( name=name, passed=False, error=f"Timed out after {KERNEL_TIMEOUT}s", ) - if not queue.empty(): - return queue.get() - return KernelResult( - name=name, passed=False, error="Kernel process exited unexpectedly" - ) + except Exception as e: + return KernelResult(name=name, passed=False, error=str(e)) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) def run_kernel_inner(name: str) -> KernelResult: From a6183d17771e013d46299a7d8a170153ee78a6b2 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Fri, 3 Apr 2026 22:44:24 -0700 Subject: [PATCH 09/14] Remove layer_norm from TPU benchmark (OOB slice bug, gh#1937) --- .github/workflows/benchmark_tpu_nightly.yml | 5 +++-- benchmarks/run_tpu.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml index 3f590fb38..7371d1b8a 100644 --- a/.github/workflows/benchmark_tpu_nightly.yml +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -16,7 +16,8 @@ on: # rms_norm: InductorLoweringError in torch.mean reduction codegen for fori_loop/emit_pipeline # geglu/swiglu: autotuning takes >15min per kernel (large shape 8x2048x4096), many configs fail to compile # low_mem_dropout: ~37% element accuracy mismatch on all configs except block_sizes=[128] - default: "exp,add,softmax_two_pass,welford,layer_norm" + # layer_norm: OOB slice when reduction_loops doesn't evenly divide the reduction dim (gh#1937) + default: "exp,add,softmax_two_pass,welford" permissions: contents: read @@ -27,4 +28,4 @@ jobs: permissions: contents: read with: - kernels: ${{ github.event.inputs.kernels || 'exp,add,softmax_two_pass,welford,layer_norm' }} + kernels: ${{ 
github.event.inputs.kernels || 'exp,add,softmax_two_pass,welford' }} diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index 9b8c583ad..e69cc1730 100644 --- a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -107,7 +107,6 @@ def _softmax_shapes() -> list[tuple[str, tuple[Any, ...]]]: _softmax_shapes, ), "welford": ("welford", "welford", None, None), - "layer_norm": ("layer_norm", "layer_norm", None, None), } From 5edc93af9b2433d246777623c335085b61327cd7 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Fri, 3 Apr 2026 23:14:37 -0700 Subject: [PATCH 10/14] Use default autotuning settings for TPU benchmark (match GPU) --- .github/workflows/benchmark_tpu.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index 04af1cbab..ed6de205e 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -17,8 +17,6 @@ jobs: env: HELION_BACKEND: pallas HELION_AUTOTUNE_LOG_LEVEL: INFO - HELION_AUTOTUNE_EFFORT: quick - HELION_AUTOTUNE_MAX_GENERATIONS: 2 runs-on: linux.google.tpuv7x.1 From e156b9574ea2ec6db6d14e3771d44fb73df7cb35 Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Sat, 4 Apr 2026 12:47:05 -0700 Subject: [PATCH 11/14] Add 7 more passing kernels to TPU benchmark New kernels: attention, bmm, geglu, grpo_loss, jagged_hstu_attn, low_mem_dropout, swiglu. Total: 11 kernels (up from 4). --- .github/workflows/benchmark_tpu_nightly.yml | 7 ++----- benchmarks/run_tpu.py | 7 +++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/benchmark_tpu_nightly.yml b/.github/workflows/benchmark_tpu_nightly.yml index 7371d1b8a..0914deb44 100644 --- a/.github/workflows/benchmark_tpu_nightly.yml +++ b/.github/workflows/benchmark_tpu_nightly.yml @@ -13,11 +13,8 @@ on: required: false type: string # Excluded kernels: - # rms_norm: InductorLoweringError in torch.mean reduction codegen for fori_loop/emit_pipeline - # geglu/swiglu: autotuning takes >15min per kernel (large shape 8x2048x4096), many configs fail to compile - # low_mem_dropout: ~37% element accuracy mismatch on all configs except block_sizes=[128] # layer_norm: OOB slice when reduction_loops doesn't evenly divide the reduction dim (gh#1937) - default: "exp,add,softmax_two_pass,welford" + default: "exp,add,softmax_two_pass,welford,attention,bmm,geglu,grpo_loss,jagged_hstu_attn,low_mem_dropout,swiglu" permissions: contents: read @@ -28,4 +25,4 @@ jobs: permissions: contents: read with: - kernels: ${{ github.event.inputs.kernels || 'exp,add,softmax_two_pass,welford' }} + kernels: ${{ github.event.inputs.kernels || 'exp,add,softmax_two_pass,welford,attention,bmm,geglu,grpo_loss,jagged_hstu_attn,low_mem_dropout,swiglu' }} diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index e69cc1730..b915ca742 100644 --- a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -107,6 +107,13 @@ def _softmax_shapes() -> list[tuple[str, tuple[Any, ...]]]: _softmax_shapes, ), "welford": ("welford", "welford", None, None), + "attention": ("attention", "attention", None, None), + "bmm": ("bmm", "bmm", None, None), + "geglu": ("geglu", "geglu", None, None), + "grpo_loss": ("grpo_loss", "grpo_loss_forward", None, None), + "jagged_hstu_attn": ("jagged_hstu_attn", "jagged_hstu_attn", None, None), + "low_mem_dropout": ("low_mem_dropout", "low_mem_dropout", None, None), + "swiglu": ("swiglu", "swiglu_fwd", None, None), } From 1025e4e3f5a1715e7ab3b4348f9a8044f31fab8e Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Sat, 4 Apr 2026 18:04:43 
-0700 Subject: [PATCH 12/14] Set HELION_AUTOTUNE_LOG_LEVEL=DEBUG in TPU benchmark workflow All kernels fail with "Default config failed while computing baseline" but the actual exception is hidden at INFO level. DEBUG will show the generated code and underlying error. --- .github/workflows/benchmark_tpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index ed6de205e..6035a9bb3 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -16,7 +16,7 @@ jobs: env: HELION_BACKEND: pallas - HELION_AUTOTUNE_LOG_LEVEL: INFO + HELION_AUTOTUNE_LOG_LEVEL: DEBUG runs-on: linux.google.tpuv7x.1 From 02989ffa5d7708ddc848dc890740767a5716ffdd Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Sat, 4 Apr 2026 21:01:14 -0700 Subject: [PATCH 13/14] Revert HELION_AUTOTUNE_LOG_LEVEL back to INFO for TPU benchmark --- .github/workflows/benchmark_tpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index 6035a9bb3..ed6de205e 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -16,7 +16,7 @@ jobs: env: HELION_BACKEND: pallas - HELION_AUTOTUNE_LOG_LEVEL: DEBUG + HELION_AUTOTUNE_LOG_LEVEL: INFO runs-on: linux.google.tpuv7x.1 From 5dff8ec7b18715903bc5b837fc61decdcebab31d Mon Sep 17 00:00:00 2001 From: Yifei Xu Date: Sun, 5 Apr 2026 21:57:35 -0700 Subject: [PATCH 14/14] Fix jagged_hstu_attn mapping and use quick autotuning in TPU benchmark MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix wrong kernel function name: jagged_hstu_attn -> _helion_jagged_attention_kernel - Add HELION_AUTOTUNE_EFFORT=quick to CI workflow — full effort times out for 5/11 kernels (welford, attention, geglu, grpo_loss, swiglu) --- .github/workflows/benchmark_tpu.yml | 1 + benchmarks/run_tpu.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/benchmark_tpu.yml b/.github/workflows/benchmark_tpu.yml index ed6de205e..4f39c463b 100644 --- a/.github/workflows/benchmark_tpu.yml +++ b/.github/workflows/benchmark_tpu.yml @@ -17,6 +17,7 @@ jobs: env: HELION_BACKEND: pallas HELION_AUTOTUNE_LOG_LEVEL: INFO + HELION_AUTOTUNE_EFFORT: quick runs-on: linux.google.tpuv7x.1 diff --git a/benchmarks/run_tpu.py b/benchmarks/run_tpu.py index b915ca742..000236559 100644 --- a/benchmarks/run_tpu.py +++ b/benchmarks/run_tpu.py @@ -111,7 +111,7 @@ def _softmax_shapes() -> list[tuple[str, tuple[Any, ...]]]: "bmm": ("bmm", "bmm", None, None), "geglu": ("geglu", "geglu", None, None), "grpo_loss": ("grpo_loss", "grpo_loss_forward", None, None), - "jagged_hstu_attn": ("jagged_hstu_attn", "jagged_hstu_attn", None, None), + "jagged_hstu_attn": ("jagged_hstu_attn", "_helion_jagged_attention_kernel", None, None), "low_mem_dropout": ("low_mem_dropout", "low_mem_dropout", None, None), "swiglu": ("swiglu", "swiglu_fwd", None, None), }
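For reference, the mapping fixed above follows the (module_file, kernel_fn_name, baseline_fn, shapes_fn) format documented at the top of benchmarks/run_tpu.py. An illustrative entry is sketched below; "my_example" is a hypothetical name, not one of the benchmarked kernels:

import torch

# Illustrative KERNEL_MAPPINGS entry only; values mirror the documented format.
EXAMPLE_ENTRY = {
    "my_example": (
        "my_example",  # module file: examples/my_example.py
        "my_example",  # attribute name of the Helion kernel in that module
        torch.relu,    # baseline callable; None means "fall back to module.main()"
        None,          # shapes fn returning (label, args) pairs; None -> main()
    ),
}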