
Commit 40bb175

ProExpertProg, xinyu-intel, chaojun-zhang, and Luka Govedič authored
[vLLM IR] 1/N Implement IR skeleton and rms_norm op (vllm-project#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
1 parent 0fab52f commit 40bb175

49 files changed: 2177 additions & 265 deletions


.buildkite/test_areas/kernels.yaml

Lines changed: 10 additions & 0 deletions
@@ -2,6 +2,16 @@ group: Kernels
 depends_on:
 - image-build
 steps:
+- label: vLLM IR Tests
+  timeout_in_minutes: 10
+  working_dir: "/vllm-workspace/"
+  source_file_dependencies:
+  - vllm/ir
+  - vllm/kernels
+  commands:
+  - pytest -v -s tests/ir
+  - pytest -v -s tests/kernels/ir
+
 - label: Kernels Core Operation Test
   timeout_in_minutes: 75
   source_file_dependencies:

.github/CODEOWNERS

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,9 @@
 /vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
 /vllm/model_executor/model_loader @22quinn
 /vllm/model_executor/layers/batch_invariant.py @yewentao256
+/vllm/ir @ProExpertProg
+/vllm/kernels/ @ProExpertProg @tjtanaa
+/vllm/kernels/helion @ProExpertProg @zou3519
 /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
 /vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
 CMakeLists.txt @tlrmchlsmth @LucasWilkinson
@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
 /tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
 /tests/evals @mgoin @vadiklyutiy
 /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
+/tests/kernels/ir @ProExpertProg @tjtanaa
 /tests/models @DarkLight1337 @ywang96
 /tests/multimodal @DarkLight1337 @ywang96 @NickLucche
 /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety

tests/compile/backend.py

Lines changed: 12 additions & 4 deletions
@@ -8,7 +8,7 @@

 import depyf
 from torch import fx
-from torch._ops import OpOverload
+from torch._ops import OpOverload, OpOverloadPacket
 from torch.fx._utils import lazy_format_graph_code

 from vllm.compilation.passes.fx_utils import find_op_nodes
@@ -90,7 +90,9 @@ def post_pass(self, graph: fx.Graph):
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph

-    def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
+    def check_before_ops(
+        self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
+    ):
         for op in ops:
             num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
             num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
@@ -99,13 +101,19 @@ def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
         if fully_replaced:
             assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"

-    def check_after_ops(self, ops: Sequence[OpOverload]):
+    def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
         for op in ops:
             num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
             num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
             assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
             assert num_post > 0, f"Op {op.name()} not found in post-pass graph"

-    def op_count(self, op: OpOverload, before=False) -> int:
+    def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
         graph = self.graph_pre_pass if before else self.graph_post_pass
         return len(list(find_op_nodes(op, graph)))
+
+    def print_graphs(self):
+        print("=== Graph before custom passes ===")
+        print(self.graph_pre_pass.python_code(root_module="self", verbose=True).src)
+        print("=== Graph after custom passes ===")
+        print(self.graph_post_pass.python_code(root_module="self", verbose=True).src)
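Note: with the widened signatures above, an OpOverloadPacket such as torch.ops.vllm_ir.rms_norm (registered elsewhere in this commit) can be passed to the check helpers directly, without picking a .default overload. A minimal sketch of how that might look in a test; `model`, `x`, and `lowering_pass` are assumptions, set up as in the IR lowering test added later in this commit:

```python
import torch

from tests.compile.backend import TestBackend
from vllm import ir

# Assumed for illustration: `model`, `x`, and `lowering_pass` configured as in
# the IR lowering test added under tests/compile/passes/ir in this commit.
backend = TestBackend(lowering_pass)
compiled = torch.compile(model, backend=backend, fullgraph=True)
with ir.enable_torch_wrap(True):
    compiled(x)

# An OpOverloadPacket is now accepted alongside OpOverload, so the packet
# torch.ops.vllm_ir.rms_norm can be queried without a ".default" suffix.
assert backend.op_count(torch.ops.vllm_ir.rms_norm, before=True) > 0

# The lowering pass is expected to replace every occurrence of the IR op.
backend.check_before_ops([torch.ops.vllm_ir.rms_norm], fully_replaced=True)

# New debugging helper: dump the FX graphs captured before and after the passes.
backend.print_graphs()
```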

tests/compile/fusions_e2e/test_tp1_quant.py

Lines changed: 3 additions & 0 deletions
@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
+
     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
         custom_ops=custom_ops.split(","),
@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
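Note: a hedged sketch of what these keyword arguments amount to, assuming (as in other fusion e2e tests) that `model_kwargs` are eventually forwarded to the `vllm.LLM` constructor; the model name below is a placeholder, not the one used by the test:

```python
from vllm import LLM

model_kwargs = {
    "load_format": "dummy",
    "max_model_len": 1024,
    # New in this commit: disable FlashInfer autotuning for these e2e runs.
    "kernel_config": {"enable_flashinfer_autotune": False},
}

# Hypothetical invocation; the actual tests build the engine through their own helpers.
llm = LLM(model="<model-under-test>", **model_kwargs)
```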

tests/compile/fusions_e2e/test_tp2_ar_rms.py

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,

tests/compile/fusions_e2e/test_tp2_async_tp.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
+    model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,

tests/compile/passes/distributed/test_sequence_parallelism.py

Lines changed: 7 additions & 7 deletions
@@ -9,7 +9,6 @@
 from tests.utils import TestFP8Layer, multi_gpu_test
 from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
 from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
-from vllm.compilation.passes.fx_utils import find_auto_fn
 from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
 from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
 from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
@@ -86,13 +85,14 @@ def ops_in_model_after(self):
         ]

     def ops_in_model(self):
-        if RMSNorm.enabled():
-            return [
-                torch.ops._C.rms_norm.default,
+        return (
+            [torch.ops.vllm_ir.rms_norm]
+            + [
                 torch.ops._C.fused_add_rms_norm.default,
             ]
-        else:
-            return []
+            if RMSNorm.enabled()
+            else []
+        )


 class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
         assert backend.op_count(op, before=False) == 4

     for op in model.ops_in_model():
-        find_auto_fn(backend.graph_post_pass.nodes, op)
+        assert backend.op_count(op, before=False) > 0

tests/compile/passes/ir/__init__.py

Whitespace-only changes.
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from torch import nn
+
+import vllm.kernels  # noqa: F401 to register kernels
+from vllm import ir
+from vllm.compilation.passes.ir.lowering_pass import (
+    VllmIRLoweringPass,
+)
+from vllm.config import get_current_vllm_config
+from vllm.ir import ops
+from vllm.platforms import current_platform
+
+from ...backend import TestBackend
+
+
+class Model(nn.Module):
+    def __init__(self, hidden_size=16, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hidden_size = hidden_size
+        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)
+
+    def forward(self, x):
+        x1 = x + 4.0
+        x2 = ops.rms_norm(x1, self.weight, 1e-5)
+        x3 = x2 * 5.0
+        # no weight
+        x4 = ops.rms_norm(x3, None, 1e-5)
+        x5 = x4 / 2.0
+        # dispatch to native due to variance_size parameter
+        x6 = ops.rms_norm(x5, self.weight, 1e-5, self.hidden_size // 2)
+        return x6 + 3.0
+
+
+@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
+def test_lowering_rms_norm(rms_provider, default_vllm_config):
+    torch.set_default_device(current_platform.device_type)
+
+    lowering_pass = VllmIRLoweringPass(get_current_vllm_config())
+    backend = TestBackend(lowering_pass)
+    backend_unlowered = TestBackend()
+
+    model = Model()
+    x = torch.randn(8, 16, dtype=torch.bfloat16)
+    with (
+        ops.rms_norm.set_priority([rms_provider, "native"]),
+        ir.enable_torch_wrap(True),
+    ):
+        compiled_model = torch.compile(model, backend=backend, fullgraph=True)
+        compiled_unlowered_model = torch.compile(
+            model, backend=backend_unlowered, fullgraph=True
+        )
+        output = compiled_model(x)
+        output_unlowered = compiled_unlowered_model(x)
+
+    selected = lowering_pass.selected_impls["rms_norm"]
+    assert len(selected) == 3
+    assert selected["rms_norm"] == rms_provider
+    assert selected["rms_norm_1"] == rms_provider
+    assert selected["rms_norm_2"] == "native"
+
+    # Compiled function guards on global value, avoid recompilation
+    with ir.enable_torch_wrap(True):
+        output2 = compiled_model(x)
+
+    torch.testing.assert_close(output_unlowered, output)
+    torch.testing.assert_close(output_unlowered, output2)
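Note: beyond the torch.compile path exercised above, the IR op wrappers appear to be callable eagerly as well. A minimal sketch based only on the call signatures visible in this test; the provider list and the CUDA device are assumptions:

```python
import torch

import vllm.kernels  # noqa: F401  (registers kernel providers, as in the test imports)
from vllm.ir import ops

x = torch.randn(8, 16, dtype=torch.bfloat16, device="cuda")
w = torch.ones(16, dtype=torch.bfloat16, device="cuda")

# Pin a provider order; "native" is the fallback the test relies on.
with ops.rms_norm.set_priority(["native"]):
    y = ops.rms_norm(x, w, 1e-5)           # weighted RMSNorm
    y_no_w = ops.rms_norm(x, None, 1e-5)   # weightless variant
    # Passing a variance_size dispatches to the native implementation,
    # per the comment in the test model above.
    y_partial = ops.rms_norm(x, w, 1e-5, 8)
```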

tests/compile/passes/test_fusion.py

Lines changed: 10 additions & 8 deletions
@@ -6,6 +6,7 @@
 import torch

 import vllm.config
+import vllm.ir.ops
 import vllm.plugins
 from tests.compile.backend import TestBackend
 from tests.utils import TestBlockFP8Layer, TestFP8Layer
@@ -51,7 +52,6 @@

 FP8_DTYPE = current_platform.fp8_dtype()

-RMS_OP = torch.ops._C.rms_norm.default
 RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

 # Kernel and group_shape combinations: (kernel, group_shape)
@@ -246,10 +246,8 @@ def ops_in_model_after(self):
         ]

     def ops_in_model_before_partial(self):
-        return (
-            [RMS_OP, RMS_ADD_OP]
-            if self.enable_rms_norm_custom_op
-            else [torch.ops.aten.rsqrt]
+        return [torch.ops.vllm_ir.rms_norm] + (
+            [RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
         )

@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
         ),
     )

-    with vllm.config.set_current_vllm_config(vllm_config):
+    with (
+        vllm.config.set_current_vllm_config(vllm_config),
+        vllm_config.kernel_config.ir_op_priority.set_priority(),
+    ):
         # Setup device before model creation
         torch.set_default_device("cuda")
         torch.set_default_dtype(dtype)
@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
     # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
     if not enable_rms_norm_custom_op:
         n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
-        # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
-        assert n_add_nodes(backend.graph_pre_pass) == 7
+        # rms_norm is IR, not included
+        # 6 = 3x2 (3xRMS_ADD, 2 each)
+        assert n_add_nodes(backend.graph_pre_pass) == 6
         assert n_add_nodes(backend.graph_post_pass) == 2
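Note on the updated add-node arithmetic in the last hunk: with plain rms_norm now captured as a single vllm_ir op, only the three fused_add_rms_norm layers decompose into aten.add pairs before fusion (3 x 2 = 6), and only the final fused rmsnorm plus residual add survives after it (2). A small sketch of the counting helper the test builds inline, assuming find_op_nodes behaves as used above:

```python
import torch

from vllm.compilation.passes.fx_utils import find_op_nodes


def n_add_nodes(graph) -> int:
    # Count aten.add call nodes in an FX graph, mirroring the test's inline lambda.
    return sum(1 for _ in find_op_nodes(torch.ops.aten.add, graph))

# Expected in the test when the rms_norm custom op is disabled:
#   n_add_nodes(backend.graph_pre_pass)  == 6
#   n_add_nodes(backend.graph_post_pass) == 2
```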