[Pallas] Fix ZeroDivisionError in block spec for int64 1D tensors #10865
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: Tests

on:
  pull_request:
  push:
    branches:
      - main
      - "release/*"

# One in-flight run per ref; on main, key on run_number so pushes never cancel
# each other, while PR pushes cancel the previous run for the same ref.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

jobs:
| load-matrix: | |
| runs-on: ubuntu-latest | |
| outputs: | |
| matrix: ${{ steps.set-matrix.outputs.matrix }} | |
| steps: | |
| - name: Checkout repository | |
| uses: actions/checkout@v6 | |
| - name: Load matrix from file | |
| id: set-matrix | |
| run: | | |
| matrix=$(cat .github/matrix.json | jq -c .) | |
| echo "matrix=$matrix" >> $GITHUB_OUTPUT | |
| test: | |
| needs: load-matrix | |
| strategy: | |
| fail-fast: false | |
| matrix: ${{ fromJSON(needs.load-matrix.outputs.matrix) }} | |
| name: test-${{ matrix.runtime-version }}-py${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.backend }}-${{ matrix.alias }} | |
| container: ${{ matrix.image != '' && fromJSON(format('{{"image":"{0}","options":"{1}"}}', matrix.image, matrix.container-options)) || '' }} | |
| runs-on: ${{ matrix.runner }} | |
| defaults: | |
| run: | |
| shell: bash -l {0} | |
| steps: | |
| - name: Run NVIDIA command | |
| if: startsWith(matrix.image, 'nvidia') | |
| run: | | |
| echo "Detected NVIDIA image" | |
| nvidia-smi || echo "nvidia-smi not found" | |
| - name: Run ROCm command | |
| if: startsWith(matrix.image, 'rocm') | |
| run: | | |
| echo "Detected ROCm image" | |
| rocminfo || echo "rocminfo not found" | |
| - name: Check out code | |
| uses: actions/checkout@v6 | |
| - name: Install system dependencies | |
| run: | | |
| set -eux | |
| SUDO=$(command -v sudo 2>/dev/null || true) | |
| $SUDO apt-get update | |
| $SUDO apt-get install -y libdw1 curl wget git pkg-config zlib1g-dev build-essential | |
| - name: Install NVSHMEM | |
| if: contains(matrix.alias, 'distributed') | |
| run: | | |
| set -euxo pipefail | |
| GPU_COUNT=$(nvidia-smi -L | wc -l) | |
| if [ "$GPU_COUNT" -ne 4 ]; then | |
| echo "Error: Expected 4 GPUs but found $GPU_COUNT" | |
| exit 1 | |
| fi | |
| curl -L https://raw.githubusercontent.com/pytorch/pytorch/main/.ci/docker/common/install_cuda.sh -o install_cuda.sh | |
| chmod +x install_cuda.sh | |
| source install_cuda.sh | |
| install_nvshmem 13 3.4.5 | |
| - name: Install uv | |
| uses: astral-sh/setup-uv@v7 | |
| with: | |
| python-version: ${{ matrix.python-version }} | |
| enable-cache: true | |
| - name: Create virtual environment | |
| run: | | |
| uv venv --python ${{ matrix.python-version }} | |
| - name: Get current month | |
| id: date | |
| run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT | |
| - name: Cache dependencies | |
| id: cache | |
| uses: actions/cache@v5 | |
| with: | |
| path: | | |
| ~/.cache/uv | |
| ~/.venv | |
| key: ${{ matrix.python-version }}-${{ matrix.runtime-version }}-${{ matrix.pytorch-version }}-${{ hashFiles('.github/workflows/test.yml') }}-${{ steps.date.outputs.month }} | |
| - name: Install PyTorch | |
| run: | | |
| source .venv/bin/activate | |
| if [[ "${{ matrix.pytorch-version }}" == "pytorch-2.9" ]]; then | |
| # Install stable 2.9 from test channel | |
| uv pip install -U "torch==2.9.*" --index-url https://download.pytorch.org/whl/${{ matrix.runtime-version }} | |
| elif [[ "${{ matrix.runtime-version }}" == "tpu" ]]; then | |
| # TPU: install CPU-only PyTorch nightly (torch_tpu provides TPU backend) | |
| uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu | |
| else | |
| # Default to nightly | |
| if [[ "${{ matrix.runtime-version }}" == "cu128" ]]; then | |
| # Install nvidia-nvshmem-cu12 from cu129 index (missing on cu128) | |
| uv pip install -U --pre nvidia-nvshmem-cu12 --index-url https://download.pytorch.org/whl/nightly/cu129 | |
| fi | |
| uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.runtime-version }} | |
| fi | |
| - name: Install Triton | |
| if: matrix.backend == 'tileir' || (matrix.backend == 'triton' && steps.cache.outputs.cache-hit != 'true' && matrix.pytorch-version != 'pytorch-2.9') | |
| run: | | |
| set -x | |
| source .venv/bin/activate | |
| apt-get update | |
| apt-get install -y git | |
| apt-get install -y clang-20 clang++-20 zlib1g-dev | |
| export CC=clang-20 | |
| export CXX=clang++-20 | |
| mkdir -p /tmp/$USER | |
| cd /tmp/$USER | |
| uv pip uninstall triton pytorch-triton || true | |
| rm -rf triton/ || true | |
| if [[ "${{ matrix.backend }}" == "tileir" ]]; then | |
| git clone --recursive -b main https://github.com/triton-lang/Triton-to-tile-IR.git triton | |
| else | |
| git clone https://github.com/triton-lang/triton.git triton | |
| if [[ "${{ matrix.python-version }}" == "3.14" ]]; then | |
| # Pin Python 3.14 nightly to known-good Triton revision until backend detection is fixed upstream. | |
| git -C triton checkout 77a13369 | |
| else | |
| git -C triton checkout 9844da95 | |
| fi | |
| fi | |
| cd triton/ | |
| uv pip install -r python/requirements.txt | |
| MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install . | |
| cd /tmp/$USER | |
| rm -rf triton/ | |
| python -c "import triton; print(f'Triton version: {triton.__version__}')" | |
| - name: Pin networkx for Python 3.14 | |
| if: matrix.python-version == '3.14' | |
| run: | | |
| source .venv/bin/activate | |
| uv pip install networkx==2.8.8 | |
| - name: Install Helion | |
| run: | | |
| source .venv/bin/activate | |
| uv pip install setuptools ninja | |
| SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]' | |
| python -c "import helion; print(helion.__name__)" | |
| - name: Install TPU dependencies (Pallas) | |
| if: matrix.alias == 'tpu' | |
| run: | | |
| set -euxo pipefail | |
| source .venv/bin/activate | |
| uv pip install \ | |
| --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ | |
| --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ | |
| --pre \ | |
| 'jax==0.9.2' 'jaxlib==0.9.2' 'libtpu==0.0.37' 'tpu-info==0.7.1' 'jaxtyping' 'frozendict' | |
| # Install Bazel | |
| if ! command -v bazel &> /dev/null; then | |
| sudo curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-linux-amd64 -o /usr/local/bin/bazel | |
| sudo chmod +x /usr/local/bin/bazel | |
| fi | |
| # Install gcloud CLI if not present (needed for Secret Manager) | |
| if ! command -v gcloud &> /dev/null; then | |
| sudo apt-get install -y apt-transport-https ca-certificates gnupg curl | |
| curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg | |
| echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee /etc/apt/sources.list.d/google-cloud-sdk.list | |
| sudo apt-get update && sudo apt-get install -y google-cloud-cli | |
| fi | |
| # Clone torch_tpu via GCP Secret Manager SSH key (same as pytorch CI) | |
| TORCH_TPU_COMMIT=$(cat .github/ci_commit_pins/torch_tpu.txt) | |
| set +x | |
| gcloud secrets versions access latest \ | |
| --secret="torchtpu-read-key" \ | |
| --project="ml-velocity-actions-testing" > /tmp/torch_tpu_ssh_key | |
| set -x | |
| chmod 600 /tmp/torch_tpu_ssh_key | |
| GIT_SSH_COMMAND="ssh -i /tmp/torch_tpu_ssh_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" \ | |
| git clone git@github.com:google-pytorch/torch_tpu.git /tmp/torch_tpu | |
| rm -f /tmp/torch_tpu_ssh_key | |
| cd /tmp/torch_tpu | |
| git checkout "${TORCH_TPU_COMMIT}" | |
| # Build torch_tpu wheel | |
| export TORCH_SOURCE=$(python -c "import torch; import os; print(os.path.dirname(os.path.dirname(torch.__file__)))") | |
| export SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") | |
| bazel build -c opt //ci/wheel:torch_tpu_wheel --config=helion_public_caching_readwrite --define WHEEL_VERSION=0.1.0 --define TORCH_SOURCE=local --action_env=PYTHONPATH=$TORCH_SOURCE:$SITE_PACKAGES --action_env=JAX_PLATFORMS=cpu | |
| uv pip install bazel-bin/ci/wheel/*.whl | |
| cd - | |
| rm -rf /tmp/torch_tpu | |
| # Verify | |
| python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')" | |
| - name: Install CUTLASS CuTe DSL | |
| if: matrix.backend == 'cute' | |
| run: | | |
| source .venv/bin/activate | |
| SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install .'[cute-cu12]' | |
| - name: CUDA Compute Check | |
| if: startsWith(matrix.image, 'nvidia') | |
| run: | | |
| source .venv/bin/activate | |
| python -c " | |
| import torch, sys | |
| assert torch.cuda.is_available(), 'FATAL: CUDA not available' | |
| n = torch.cuda.device_count() | |
| assert n > 0, 'FATAL: No CUDA devices found' | |
| print(f'CUDA devices: {n}') | |
| for i in range(n): | |
| dev = torch.device('cuda', i) | |
| a = torch.randn(256, 256, device=dev) | |
| b = (a @ a).sum().item() | |
| print(f' Device {i} ({torch.cuda.get_device_name(i)}): OK') | |
| print(f'All {n} devices healthy') | |
| " | |
| - name: Run Tests | |
| run: | | |
| set -o pipefail | |
| source .venv/bin/activate | |
| uv pip install pytest-xdist | |
| # Conditionally enable ref-eager and golden-accept/dtype-assert test modes | |
| if [[ "${{ matrix.dtype-asserts }}" == "true" ]]; then export HELION_DEBUG_DTYPE_ASSERTS=1; fi | |
| if [[ "${{ matrix.expecttest-accept }}" == "true" ]]; then export EXPECTTEST_ACCEPT=1; fi | |
| if [[ "${{ matrix.ref-eager }}" == "true" ]]; then export HELION_INTERPRET=1; fi | |
| if [[ "${{ matrix.backend }}" == "tileir" ]]; then export ENABLE_TILE=1; fi | |
| export HELION_BACKEND=${{ matrix.backend }} | |
| # -rf: print failed tests | |
| # --timeout: max allowed time for each test | |
| PARALLEL="-n4" | |
| if [[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]]; then | |
| TEST_PATH="test/test_examples_dist.py" | |
| EXTRA_FLAGS="-rs" | |
| elif [[ "${{ matrix.alias }}" == "tpu" ]]; then | |
| TEST_PATH="." | |
| EXTRA_FLAGS="--ignore=test/test_examples_dist.py" | |
| PARALLEL="" | |
| else | |
| TEST_PATH="." | |
| EXTRA_FLAGS="--ignore=test/test_examples_dist.py" | |
| fi | |
| # For distributed tests, fail if any test is skipped, failed, or has an error | |
| SKIP_CHECK=$([[ "${{ contains(matrix.alias, 'distributed') }}" == "true" ]] && echo "! grep -qE '(SKIPPED|FAILED|ERROR)'" || echo "cat > /dev/null") | |
| pytest $PARALLEL -rf --timeout=60 $EXTRA_FLAGS $TEST_PATH | tee >(eval $SKIP_CHECK) | |
| test-notebooks: | |
| name: test-notebooks-cu128-py3.12-pytorch-2.9-a10g | |
| container: | |
| image: nvidia/cuda:12.8.1-devel-ubuntu24.04 | |
| options: --gpus all | |
| runs-on: linux.g5.4xlarge.nvidia.gpu | |
| defaults: | |
| run: | |
| shell: bash -l {0} | |
| steps: | |
| - name: Run NVIDIA command | |
| run: | | |
| echo "Detected NVIDIA image" | |
| nvidia-smi || echo "nvidia-smi not found" | |
| - name: Check out code | |
| uses: actions/checkout@v6 | |
| - name: Install uv | |
| uses: astral-sh/setup-uv@v7 | |
| with: | |
| python-version: "3.12" | |
| enable-cache: true | |
| - name: Create virtual environment | |
| run: | | |
| uv venv --python 3.12 | |
| - name: Install pip in venv | |
| run: | | |
| source .venv/bin/activate | |
| uv pip install pip | |
| - name: Get current month | |
| id: date | |
| run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT | |
| - name: Cache dependencies | |
| id: cache | |
| uses: actions/cache@v5 | |
| with: | |
| path: | | |
| ~/.cache/uv | |
| ~/.venv | |
| key: notebooks-3.12-cu128-${{ hashFiles('.github/workflows/test.yml') }}-${{ steps.date.outputs.month }} | |
| - name: Install notebook execution tools | |
| run: | | |
| source .venv/bin/activate | |
| # Install jupyter for executing notebooks | |
| uv pip install jupyter nbconvert pytest numpy "nbclient<0.10" | |
| - name: Run Notebook Tests | |
| run: | | |
| source .venv/bin/activate | |
| # Execute notebook using jupyter nbconvert | |
| # The notebook's subprocess pip install will install torch and helion | |
| jupyter nbconvert --to notebook --execute --inplace \ | |
| --ExecutePreprocessor.timeout=600 \ | |
| notebooks/softmax.ipynb |