Skip to content

Commit 93ade94

Browse files
authored
[hexagon-mlir] update repository (#67)
Update hexagon-mlir with new features. --------- Signed-off-by: mabsar <mabsar@qti.qualcommm.com>
1 parent dc4667f commit 93ade94

240 files changed

Lines changed: 25654 additions & 1688 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitmodules

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[submodule "triton"]
2-
path = triton
3-
url = https://github.com/triton-lang/triton.git
2+
path = triton
3+
url = https://github.com/triton-lang/triton.git
44
[submodule "triton_shared"]
5-
path = triton_shared
6-
url = https://github.com/microsoft/triton-shared
5+
path = triton_shared
6+
url = https://github.com/facebookincubator/triton-shared.git

ci/apply_patches.sh

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/bin/bash
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause.
5+
# For more license information:
6+
# https://github.com/qualcomm/hexagon-mlir/LICENSE.txt
7+
#
8+
set -Eeuox pipefail
9+
10+
# If the CI image accidentally has /etc/gitconfig as a *directory*, Git will fail
11+
# when reading the system config. We detect that here and warn that this script
12+
# will ignore the system gitconfig by using GIT_CONFIG_NOSYSTEM=1 on all git calls.
13+
if [ -d /etc/gitconfig ]; then
14+
echo "WARNING: Detected /etc/gitconfig is a DIRECTORY; system gitconfig will be ignored via GIT_CONFIG_NOSYSTEM=1 for all git operations in this script."
15+
fi
16+
17+
SCRIPT_DIR="$(readlink -f "$(dirname "$0")")"
18+
HEXAGON_MLIR_ROOT="$(readlink -f "$SCRIPT_DIR/../")"
19+
TRITON_ROOT="$HEXAGON_MLIR_ROOT/triton"
20+
21+
# Apply a patch if it isn't already applied (stateless; no marker file).
22+
apply_patch_if_needed() {
23+
local repo_dir="$1" # e.g., "$TRITON_ROOT" or "$HEXAGON_MLIR_ROOT/triton_shared"
24+
local patch_file="$2" # e.g., ".../patches/triton/third_party_triton.patch"
25+
26+
if [ ! -f "$patch_file" ]; then
27+
echo "WARNING: Patch file not found at $patch_file"
28+
return 0
29+
fi
30+
31+
echo "Checking/applying patch: $patch_file in $repo_dir"
32+
pushd "$repo_dir" >/dev/null
33+
34+
# 1) If the reverse applies, the patch is already present — skip.
35+
if GIT_CONFIG_NOSYSTEM=1 git apply --reverse --check "$patch_file" >/dev/null 2>&1; then
36+
echo "Patch already applied (reverse-check passed): $patch_file — skipping."
37+
popd >/dev/null
38+
return 0
39+
fi
40+
41+
# 2) Try to apply the patch directly.
42+
# git apply is the source of truth to avoid TOCTOU races.
43+
if GIT_CONFIG_NOSYSTEM=1 git apply "$patch_file"; then
44+
echo "Patch applied successfully: $patch_file"
45+
popd >/dev/null
46+
return 0
47+
fi
48+
49+
# 3) Neither forward nor reverse apply cleanly -> inconsistent state / conflicts.
50+
echo "ERROR: Patch neither applies nor is already applied: $patch_file"
51+
echo "----- git apply --check (verbose) output -----"
52+
GIT_CONFIG_NOSYSTEM=1 git apply --check -v "$patch_file" || true
53+
echo "----------------------------------------------"
54+
popd >/dev/null
55+
exit 1
56+
}
57+
58+
# -----------------------------------------------------------------------------
59+
# Apply patches (no marker-file argument is needed; keep the order below, since patches may depend on one another)
60+
# -----------------------------------------------------------------------------
61+
62+
# -----------------------------------------------------------------------------
63+
# triton_shared patches
64+
# -----------------------------------------------------------------------------
65+
# Triton shared patch to update the API for compatibility with the latest LLVM
66+
TRITON_SHARED_API_UPDATE_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_3_6_triton.patch"
67+
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_API_UPDATE_PATCH_FILE"
68+
69+
# Triton shared patch on Pointer Analysis
70+
TRITON_SHARED_POINTER_ANALYSIS_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_ptr_analysis.patch"
71+
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_POINTER_ANALYSIS_PATCH_FILE"
72+
73+
# Triton shared patch on split pointers
74+
TRITON_SHARED_SPLIT_DIM_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_split_dim.patch"
75+
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_SPLIT_DIM_PATCH_FILE"
76+
77+
# Triton shared patch to handle canonicalization pattern of Max with NaN propagation
78+
TRITON_SHARED_MAX_NAN_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_max_nan.patch"
79+
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_MAX_NAN_PATCH_FILE"
80+
81+
# -----------------------------------------------------------------------------
82+
# Triton patches (build third-party backends + NVVM ReductionKind compatibility + libdevice sigmoid support)
83+
# -----------------------------------------------------------------------------
84+
# Triton patch to get around NVVM ReductionKind compatibility issue
85+
TRITON_NVVM_COMPATIBILITY_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton/nvvm_reduction_kind_compatibility.patch"
86+
apply_patch_if_needed "$TRITON_ROOT" "$TRITON_NVVM_COMPATIBILITY_PATCH_FILE"
87+
88+
# Add libdevice sigmoid support to triton
89+
TRITON_LIBDEVICE_SIGMOID_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton/libdevice_sigmoid.patch"
90+
apply_patch_if_needed "$TRITON_ROOT" "$TRITON_LIBDEVICE_SIGMOID_PATCH_FILE"

ci/hexagon-mlir-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ pybind11
66
scipy
77
lit
88
wheel
9-
transformers==4.52.4
9+
transformers==4.52.4

ci/setup_submodules.sh

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,49 @@
44
# SPDX-License-Identifier: BSD-3-Clause.
55
# For more license information:
66
# https://github.com/qualcomm/hexagon-mlir/LICENSE.txt
7-
#
7+
8+
89
set -euo pipefail
910

1011
REPO_ROOT="$(git rev-parse --show-toplevel)"
11-
echo "Configuring git for submodules"
12+
echo "Configuring git submodules"
1213
cd "${REPO_ROOT}"
13-
# Add submodules if missing
14-
if [ ! -d "triton" ]; then
15-
git submodule add --force https://github.com/triton-lang/triton.git triton
16-
cd triton
17-
echo "Applying qcom specific patches to triton"
18-
git checkout e44bd1c83c1c3e8deac7c4f02683cfb3cc395c8b
19-
git apply "${REPO_ROOT}/third_party_software/patches/triton/third_party_triton.patch"
20-
fi
2114

22-
cd "${REPO_ROOT}"
23-
if [ ! -d "triton_shared" ]; then
24-
git submodule add --force https://github.com/microsoft/triton-shared triton_shared
25-
cd triton_shared
26-
git checkout 2b728ad97bc02af821a0805b09075838911d4c19
27-
echo "Applying qcom specific patches to triton_shared"
28-
git apply "${REPO_ROOT}/third_party_software/patches/triton_shared/max_with_nan_propagation.patch"
29-
git apply "${REPO_ROOT}/third_party_software/patches/triton_shared/tt_shared_split_dim.patch"
30-
fi
15+
# Ensure existing submodules are initialized
16+
git submodule update --init
17+
18+
add_and_checkout() {
19+
local name="$1"
20+
local url="$2"
21+
local commit="$3"
3122

23+
cd "${REPO_ROOT}"
24+
if [ ! -d "${REPO_ROOT}/${name}" ]; then
25+
echo "Adding submodule ${name}"
26+
git submodule add --force "${url}" "${name}"
27+
fi
28+
29+
echo "Checking out ${name} at ${commit}"
30+
cd "${REPO_ROOT}/${name}"
31+
git fetch origin
32+
git checkout "${commit}"
33+
}
34+
35+
add_and_checkout \
36+
triton \
37+
https://github.com/triton-lang/triton.git \
38+
df38505e451a1541555379bcf378be9e8c00545c
39+
40+
add_and_checkout \
41+
triton_shared \
42+
https://github.com/facebookincubator/triton-shared.git \
43+
0614763d270ec0eacba9d5d8283cdff6bedb03c8
44+
45+
cd "${REPO_ROOT}"
46+
echo "Applying qcom specific patches to triton_shared"
47+
bash "${REPO_ROOT}/ci/apply_patches.sh" || {
48+
echo "ERROR: Failed while applying patches"
49+
exit 1
50+
}
51+
3252
echo "Submodules triton and triton_shared initialized and patched successfully."

qcom_hexagon_backend/backend/compiler.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import re
1313
import tempfile
14+
from dataclasses import replace
1415
from pathlib import Path
1516
import subprocess
1617
from typing import Any, Dict, no_type_check
@@ -30,21 +31,36 @@
3031
# mypy compiler.py hexagon_executor.py hexagon_launcher_base.py torch_mlir_hexagon_launcher.py triton_hexagon_launcher.py --follow-untyped-imports --check-untyped-defs
3132

3233

33-
def _get_triton_shared_opt_path() -> str:
34-
path = os.getenv("TRITON_SHARED_OPT_PATH", "")
34+
def _get_triton_shared_opt_path(device_type: str) -> str:
35+
path = os.getenv(
36+
"TRITON_SHARED_OPT_PATH",
37+
"",
38+
)
3539
if path == "":
3640
raise Exception("TRITON_SHARED_OPT_PATH is not set.")
37-
return path
3841

42+
bin_path = Path(path).resolve()
43+
44+
if not bin_path.exists() or not bin_path.is_file():
45+
raise FileNotFoundError(
46+
f"Could not find 'triton-shared-opt' at expected location: {bin_path}"
47+
)
48+
if not os.access(bin_path, os.X_OK):
49+
raise PermissionError(
50+
f"'triton-shared-opt' exists but is not executable: {bin_path}"
51+
)
52+
53+
return str(bin_path)
3954

40-
def ttir_to_ttsharedir(mod):
55+
56+
def ttir_to_ttsharedir(mod: str, options):
4157
# Get Triton-MLIR as string
4258
ttir_code = str(mod)
4359
with tempfile.TemporaryDirectory() as tmpdir:
4460
src_path = os.path.join(tmpdir, "tt.mlir")
4561
dst_path = os.path.join(tmpdir, "ttshared.mlir")
4662
Path(src_path).write_text(ttir_code)
47-
triton_shared_opt_path = _get_triton_shared_opt_path()
63+
triton_shared_opt_path = _get_triton_shared_opt_path(options.device_type)
4864
subprocess.check_call(
4965
[
5066
triton_shared_opt_path,
@@ -90,6 +106,10 @@ def ttsharedir_to_obj(mod: str, options, metadata={}) -> bytes:
90106
options_map = {k: str(v) for k, v in (options.__dict__).items()}
91107
# TODO: Move setting benchmarking iterations when additional stage for shared object creation is part of compilation pipeline.
92108
metadata["iterations"] = options_map["iterations"]
109+
metadata["scratch"] = options_map["scratch"]
110+
metadata["enableMultiThreading"] = options_map["enableMultiThreading"]
111+
metadata["enableThreadedDispatch"] = options_map["enableThreadedDispatch"]
112+
metadata["enableLWP"] = options_map["enableLWP"]
93113

94114
# TODO: The lowering pipeline needs to be refactored similar to other Triton backends to
95115
# have a dynamic pipeline filtered by options with each pass represented by a pybind function.
@@ -138,13 +158,31 @@ def hash(self):
138158
return f"{version}-{self.target}"
139159

140160
def parse_options(self, opts) -> Any:
141-
assert self.target.backend == "hexagon"
142-
args = {
143-
k: opts[k]
144-
for k in HexagonOptions.__dataclass_fields__.keys()
145-
if k in opts
146-
}
147-
return HexagonOptions(**args)
161+
assert self.target.backend == "hexagon"
162+
args = {
163+
k: opts[k] for k in HexagonOptions.__dataclass_fields__.keys() if k in opts
164+
}
165+
hexagon_opts = HexagonOptions(**args)
166+
167+
# When external VTCM scratch is enabled (scratch > 0), automatically
168+
# configure flags for correct SPMD behavior at compile time so that
169+
# both the compiled IR and the generated wrapper are consistent:
170+
# - Disable enableConvertToHexagonmem to prevent VTCMPool from
171+
# concurrently trying to allocate VTCM alongside the external pool.
172+
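# - Disable enableHexagonmemCopyToDMA to avoid DMA/thread conflicts. NOTE(review): the replace() call below does not set this flag; it sets enableMultiThreading=False and enableVTCMTiling=True instead — confirm which is intended.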
173+
# - Enable enableThreadedDispatch so instances run in parallel on
174+
# real qurt hardware threads (handled at wrapper generation time).
175+
# Users do not need to set these flags manually when scratch > 0.
176+
if hexagon_opts.scratch > 0:
177+
hexagon_opts = replace(
178+
hexagon_opts,
179+
enableMultiThreading=False,
180+
enableConvertToHexagonmem=False,
181+
enableVTCMTiling=True,
182+
enableThreadedDispatch=True,
183+
)
184+
185+
return hexagon_opts
148186

149187
@staticmethod
150188
def make_ttir(mod, metadata, opt):
@@ -161,13 +199,13 @@ def make_ttir(mod, metadata, opt):
161199
passes.common.add_symbol_dce(pm)
162200
passes.ttir.add_loop_unroll(pm)
163201
passes.common.add_cse(pm)
164-
pm.run(mod)
202+
pm.run(mod, "make_ttir")
165203
return mod
166204

167205
# May need to add num_warps
168206
def add_stages(self, stages, options, language):
169207
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
170-
stages["ttsharedir"] = lambda src, metadata: ttir_to_ttsharedir(src)
208+
stages["ttsharedir"] = lambda src, metadata: ttir_to_ttsharedir(src, options)
171209
if options.htp_kernel_gen:
172210
if options.target_artifact == "llir":
173211
stages["llir"] = lambda src, metadata: ttsharedir_to_llir(
@@ -190,6 +228,8 @@ def add_stages(self, stages, options, language):
190228
src, options, metadata
191229
)
192230
else: # Default compilation pipeline
231+
assert options.device_type == "hexagon"
232+
193233
stages["o"] = lambda src, metadata: ttsharedir_to_obj(
194234
src, options, metadata
195235
)
@@ -240,6 +280,10 @@ def pack_metadata(self, metadata):
240280
metadata.name,
241281
metadata.return_types,
242282
metadata.iterations,
283+
metadata.scratch,
284+
metadata.enableMultiThreading,
285+
metadata.enableThreadedDispatch,
286+
metadata.enableLWP,
243287
)
244288

245289
def get_module_map(self) -> Dict[str, ModuleType]:

qcom_hexagon_backend/backend/driver.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ def __call__(self, *args, **kwargs):
4848
"_mlir_ciface_" if len(return_profs) > 0 else ""
4949
) + pack_metadata[6]
5050
iterations = pack_metadata[8]
51+
compiled_scratch = pack_metadata[9] if len(pack_metadata) > 9 else None
52+
compiled_enable_multithreading = (
53+
pack_metadata[10] if len(pack_metadata) > 10 else None
54+
)
55+
compiled_enable_threaded_dispatch = (
56+
pack_metadata[11] if len(pack_metadata) > 11 else None
57+
)
58+
compiled_enable_lwp = pack_metadata[12] if len(pack_metadata) > 12 else None
5159
num_fixed_args = 9
5260
inputs_with_constants = list(args[num_fixed_args:])
5361
inputs = [
@@ -65,7 +73,17 @@ def __call__(self, *args, **kwargs):
6573
"""
6674
)
6775
self.launcher._exec_kernel(
68-
kernel_llir, iterations, func_name, inputs, return_profs, launch_grid
76+
kernel_llir,
77+
iterations,
78+
func_name,
79+
inputs,
80+
return_profs,
81+
launch_grid,
82+
compiled_scratch=compiled_scratch,
83+
compiled_enable_multithreading=compiled_enable_multithreading,
84+
compiled_enable_threaded_dispatch=compiled_enable_threaded_dispatch,
85+
compiled_enable_lwp=compiled_enable_lwp,
86+
runtime_options=kwargs,
6987
)
7088
# TODO: There seems to be no way to propagate the call returns upward, because
7189
# - The call result is not used by the caller
@@ -132,3 +150,19 @@ def get_active_torch_device(self):
132150

133151
def get_current_stream(self, device):
134152
return None
153+
154+
def get_device_interface(self):
155+
import torch
156+
157+
return torch.cpu
158+
159+
def get_empty_cache_for_benchmark(self):
160+
import torch
161+
162+
device = "cpu"
163+
# 256MB cache
164+
cache_size = 256 * 1024 * 1024
165+
return torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
166+
167+
def clear_cache(self, cache):
168+
cache.zero_()

0 commit comments

Comments
 (0)