Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions .github/workflows/benchmark_tpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Reusable workflow: it has no triggers of its own and only runs when another
# workflow invokes it through `workflow_call`.
name: Benchmark TPU

on:
  workflow_call:
    inputs:
      # Kernel selection supplied by the caller; it is forwarded verbatim to
      # the `--kernel` option of benchmarks/run_tpu.py in the benchmark step.
      kernels:
        type: string
        required: true

# Least privilege: the token only needs to read the repository contents.
permissions:
  contents: read

jobs:
  benchmark:
    name: benchmark-tpu-pallas

    # Runner label for the TPU machine this job is scheduled on.
    runs-on: linux.google.tpuv7x.1

    # Exported to every step of the job.
    # NOTE(review): presumably consumed by Helion to select the Pallas backend
    # and to bound autotuning effort/log verbosity — confirm against Helion docs.
    env:
      HELION_BACKEND: pallas
      HELION_AUTOTUNE_LOG_LEVEL: INFO
      HELION_AUTOTUNE_EFFORT: quick

    defaults:
      run:
        # Login shell for all `run` steps. Because this is a custom shell
        # template, GitHub does not add `-e` / `-o pipefail` on its own; steps
        # that need fail-fast behaviour must call `set -e` themselves.
        shell: bash -l {0}

    steps:
# Fetch the repository; later steps install it in editable mode and run
# scripts from its benchmarks/ directory.
- name: Check out code
  uses: actions/checkout@v6

# Quoted so the version stays a string rather than being parsed as a float.
- name: Set up Python
  uses: actions/setup-python@v6
  with:
    python-version: '3.12'

- name: Install uv
  uses: astral-sh/setup-uv@v7

# Creates the .venv directory that every subsequent `run` step activates.
- name: Create virtual environment
  run: uv venv --python 3.12

# Nightly (pre-release) torch from the CPU-only wheel index, installed into
# the project virtual environment.
- name: Install PyTorch (CPU nightly)
  run: |
    source .venv/bin/activate
    uv pip install --upgrade --pre torch \
      --index-url https://download.pytorch.org/whl/nightly/cpu

# Editable install of this repository with its `dev` extra. The trailing
# import acts as a smoke test: it is the step's last command, so a broken
# install fails the step there.
- name: Install Helion
  run: |
    source .venv/bin/activate
    uv pip install setuptools ninja
    # Fixed placeholder version handed to setuptools-scm for this install.
    SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e '.[dev]'
    python -c "import helion; print(helion.__name__)"

# Installs the pinned JAX/libtpu stack, then builds and installs the torch_tpu
# wheel from a private repository at the commit pinned in
# .github/ci_commit_pins/torch_tpu.txt.
- name: Install TPU dependencies (Pallas)
  run: |
    # Fail fast (-e, -u, pipefail) and trace commands (-x) for this long step.
    set -euxo pipefail
    source .venv/bin/activate

    # JAX / libtpu stack, pinned to exact versions.
    uv pip install \
      --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
      --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
      --pre \
      'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict'

    # Install Bazel (bazelisk launcher) if not already present.
    if ! command -v bazel &> /dev/null; then
      # -f makes curl fail on HTTP errors instead of saving an error page as
      # the "binary".
      sudo curl -fsSL https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel
      sudo chmod +x /usr/local/bin/bazel
    fi

    # Install gcloud CLI if not present (needed for Secret Manager)
    if ! command -v gcloud &> /dev/null; then
      # Refresh package lists first so the install does not hit a stale index.
      sudo apt-get update
      sudo apt-get install -y apt-transport-https ca-certificates gpg curl
      curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg
      echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list
      sudo apt-get update && sudo apt-get install -y google-cloud-cli
    fi

    # Clone torch_tpu via GCP Secret Manager SSH key (same as pytorch CI)
    TORCH_TPU_COMMIT=$(cat .github/ci_commit_pins/torch_tpu.txt)

    # mktemp creates the file readable by the owner only, so the private key
    # is never world-readable, and the EXIT trap removes it even when a later
    # command aborts the script under `set -e`.
    SSH_KEY_FILE=$(mktemp)
    trap 'rm -f "${SSH_KEY_FILE}"' EXIT

    # Tracing is switched off while the secret is being fetched.
    set +x
    gcloud secrets versions access latest \
      --secret="torchtpu-read-key" \
      --project="ml-velocity-actions-testing" > "${SSH_KEY_FILE}"
    set -x

    # A checkout left behind by an earlier aborted run would make the clone fail.
    rm -rf /tmp/torch_tpu

    # NOTE(review): StrictHostKeyChecking=no disables host-key verification;
    # consider pinning GitHub's published host keys in known_hosts instead.
    GIT_SSH_COMMAND="ssh -i ${SSH_KEY_FILE} -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" \
      git clone git@github.com:google-pytorch/torch_tpu.git /tmp/torch_tpu
    # Drop the key as soon as it is no longer needed (the trap is the backstop).
    rm -f "${SSH_KEY_FILE}"

    cd /tmp/torch_tpu
    git checkout "${TORCH_TPU_COMMIT}"

    # Build torch_tpu wheel
    # Assign first, export afterwards: `export VAR=$(cmd)` would hide a failing
    # command substitution from `set -e`.
    TORCH_SOURCE=$(python -c "import torch; import os; print(os.path.dirname(os.path.dirname(torch.__file__)))")
    SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
    export TORCH_SOURCE SITE_PACKAGES
    bazel build -c opt //ci/wheel:torch_tpu_wheel \
      --config=helion_public_caching_readwrite \
      --define WHEEL_VERSION=0.1.0 \
      --define TORCH_SOURCE=local \
      --action_env=PYTHONPATH="${TORCH_SOURCE}:${SITE_PACKAGES}" \
      --action_env=JAX_PLATFORMS=cpu
    uv pip install bazel-bin/ci/wheel/*.whl
    cd -
    rm -rf /tmp/torch_tpu

    # Verify
    python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')"

# Runs benchmarks/run_tpu.py twice: an autotuning pass, then a second pass
# that asserts cache hits and writes the JSON report.
- name: Run TPU Benchmark
  env:
    # The caller-supplied value reaches the script as an environment variable
    # instead of being expanded into the script text, so shell metacharacters
    # in the input cannot be executed as commands.
    KERNELS: ${{ inputs.kernels }}
  run: |
    source .venv/bin/activate

    TEST_REPORTS_DIR=$(pwd)/test/test-reports
    mkdir -p "$TEST_REPORTS_DIR"

    echo "=========================================="
    echo "TPU Benchmark: autotuning pass"
    echo "Kernels: $KERNELS"
    echo "=========================================="

    # NOTE(review): this script never enables `set -e`, and the final if/else
    # always succeeds, so a failing run_tpu.py does not fail the step. This
    # looks like deliberate best-effort reporting - confirm.

    # First pass: autotune (populates cache)
    python benchmarks/run_tpu.py --kernel "$KERNELS" --num-shapes 1

    # Let TPU cool down
    sleep 1m

    echo "=========================================="
    echo "TPU Benchmark: cache-hit verification pass"
    echo "=========================================="

    # Second pass: verify cache hits and record results
    HELION_PRINT_OUTPUT_CODE=1 HELION_ASSERT_CACHE_HIT=1 \
      python benchmarks/run_tpu.py \
        --kernel "$KERNELS" \
        --num-shapes 1 \
        --output "$TEST_REPORTS_DIR/helionbench.json"

    # Print the report when it exists and is non-empty; otherwise just say so.
    if [[ -s "$TEST_REPORTS_DIR/helionbench.json" ]]; then
      cat "$TEST_REPORTS_DIR/helionbench.json"
    else
      echo "helionbench.json is missing or empty (some kernels may have failed)"
    fi

# Publishes the report directory filled by the benchmark step
# (test/test-reports) as a workflow artifact.
- name: Upload the benchmark results to GitHub
  uses: actions/upload-artifact@v7
  with:
    path: test/test-reports
    name: benchmark-results-tpu
28 changes: 28 additions & 0 deletions .github/workflows/benchmark_tpu_nightly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Nightly entry point for the TPU benchmark: it only decides when to run and
# which kernels to request; the work happens in the reusable workflow it calls.
name: Benchmark TPU Nightly

on:
  # TODO: remove before merging — temporary trigger for CI testing
  push:
    branches: [yifeixu/tpu-nightly-benchmark]

  # Cron is evaluated in UTC: 10:00 UTC is 2 AM PST (3 AM during PDT).
  schedule:
    - cron: "0 10 * * *"

  # Manual runs may override the kernel list.
  workflow_dispatch:
    inputs:
      kernels:
        description: 'Comma-separated list of kernels to benchmark'
        type: string
        required: false
        # Excluded kernels:
        #   layer_norm: OOB slice when reduction_loops doesn't evenly divide
        #   the reduction dim (gh#1937)
        # Keep this list in sync with the fallback list in the job's
        # `with.kernels` expression.
        default: "exp,add,softmax_two_pass,welford,attention,bmm,geglu,grpo_loss,jagged_hstu_attn,low_mem_dropout,swiglu"

# Least privilege: the token only needs to read the repository contents.
permissions:
  contents: read

jobs:
  benchmark-tpu:
    # Delegates to the reusable TPU benchmark workflow in this repository.
    uses: ./.github/workflows/benchmark_tpu.yml
    permissions:
      contents: read
    with:
      # Only manual dispatches carry an input; for the other triggers the
      # expression falls back to the hard-coded list. The folded scalar joins
      # the two lines with a single space, yielding one expression.
      kernels: >-
        ${{ github.event.inputs.kernels ||
        'exp,add,softmax_two_pass,welford,attention,bmm,geglu,grpo_loss,jagged_hstu_attn,low_mem_dropout,swiglu' }}
Loading
Loading