-
Notifications
You must be signed in to change notification settings - Fork 137
Epilogue subtiling: store indexing fix, example, and tuple output support in run_example #1907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
choijon5
wants to merge
6
commits into
main
Choose a base branch
from
epilogue-subtiling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
c0c3a59
Add epilogue subtiling pass for tmem_subslice optimization
choijon5 59e520f
Fix subtiled store indexing after config overrides
choijon5 6ec20ac
Add epilogue subtiling example and tuple output support in run_example
choijon5 89e8d02
Add epilogue_subtile/flatten_loops incompatibility check and tests
choijon5 2a40825
Track store positions in indexing list instead of assuming loads-firs…
choijon5 eec168a
Add epilogue subtiling documentation
choijon5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| """ | ||
| Epilogue Subtiling Example | ||
| ========================== | ||
|
|
||
| This example demonstrates matmul kernels with heavy epilogues that benefit from | ||
| epilogue subtiling on Blackwell (sm_100+). Epilogue subtiling splits the store | ||
| from ``[BLOCK_M, BLOCK_N]`` into ``SUBTILE_FACTOR x [BLOCK_M, BLOCK_N / SUBTILE_FACTOR]``, | ||
| halving the accumulator shared-memory footprint and enabling an extra pipeline stage. | ||
| """ | ||
|
|
||
| # %% | ||
| # Imports | ||
| # ------- | ||
|
|
||
| # %% | ||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
|
|
||
| import helion | ||
| from helion._testing import DEVICE | ||
| from helion._testing import HALF_DTYPE | ||
| from helion._testing import run_example | ||
| import helion.language as hl | ||
|
|
||
| # %% | ||
| # Kernel 1 -- Matmul + Residual + Bias + GELU + Cast | ||
| # --------------------------------------------------- | ||
| # CUTLASS-style residual + bias + GELU forward epilogue with two | ||
| # fp32 reads (residual, bias) fused into the output tile. | ||
|
|
||
|
|
||
| # %% | ||
| @helion.kernel(static_shapes=True) | ||
| def matmul_bias_residual_gelu_cast( | ||
| x: torch.Tensor, | ||
| w: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| residual: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| m, k = x.size() | ||
| _, n = w.size() | ||
| out = torch.empty([m, n], dtype=torch.float16, device=x.device) | ||
|
|
||
| for tile_m, tile_n in hl.tile([m, n]): | ||
| acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) | ||
| for tile_k in hl.tile(k): | ||
| acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n]) | ||
|
|
||
| val = acc * 1.25 | ||
| val = val + residual[tile_m, tile_n].to(torch.float32) * 0.5 | ||
| val = val + bias[tile_n] | ||
| val = torch.nn.functional.gelu(val) | ||
| out[tile_m, tile_n] = val.to(torch.float16) | ||
|
|
||
| return out | ||
|
|
||
|
|
||
| # %% | ||
| # Kernel 2 -- Matmul + Bias + GELU with Auxiliary Output | ||
| # ------------------------------------------------------ | ||
| # cuBLASLt / CUTLASS-style GELU+AUX forward epilogue that writes both | ||
| # the pre-activation (aux) and post-GELU (out) tensors. | ||
|
|
||
|
|
||
| # %% | ||
| @helion.kernel(static_shapes=True) | ||
| def matmul_bias_gelu_aux( | ||
| x: torch.Tensor, | ||
| w: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| m, k = x.size() | ||
| _, n = w.size() | ||
| out = torch.empty([m, n], dtype=torch.float16, device=x.device) | ||
| aux = torch.empty([m, n], dtype=torch.float16, device=x.device) | ||
|
|
||
| for tile_m, tile_n in hl.tile([m, n]): | ||
| acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) | ||
| for tile_k in hl.tile(k): | ||
| acc = torch.addmm(acc, x[tile_m, tile_k], w[tile_k, tile_n]) | ||
|
|
||
| pre = acc * 1.25 | ||
| pre = pre + bias[tile_n] | ||
| aux[tile_m, tile_n] = pre.to(torch.float16) | ||
| out[tile_m, tile_n] = torch.nn.functional.gelu(pre).to(torch.float16) | ||
|
|
||
| return out, aux | ||
|
|
||
|
|
||
| # %% | ||
| # Verification | ||
| # ------------ | ||
|
|
||
|
|
||
| # %% | ||
| def check(m: int, k: int, n: int) -> None: | ||
| x = torch.randn([m, k], device=DEVICE, dtype=HALF_DTYPE) | ||
| w = torch.randn([k, n], device=DEVICE, dtype=HALF_DTYPE) | ||
| bias = torch.randn([n], device=DEVICE, dtype=HALF_DTYPE) | ||
| residual = torch.randn([m, n], device=DEVICE, dtype=HALF_DTYPE) | ||
|
|
||
| def baseline_residual_gelu( | ||
| x: torch.Tensor, | ||
| w: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| residual: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| acc = x.float() @ w.float() | ||
| val = acc * 1.25 + residual.float() * 0.5 + bias.float() | ||
| return torch.nn.functional.gelu(val).half() | ||
|
|
||
| run_example( | ||
| matmul_bias_residual_gelu_cast, | ||
| baseline_residual_gelu, | ||
| (x, w, bias, residual), | ||
| ) | ||
|
|
||
| def baseline_gelu_aux( | ||
| x: torch.Tensor, | ||
| w: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| acc = x.float() @ w.float() | ||
| pre = acc * 1.25 + bias.float() | ||
| return torch.nn.functional.gelu(pre).half(), pre.half() | ||
|
|
||
| run_example( | ||
| matmul_bias_gelu_aux, | ||
| baseline_gelu_aux, # pyrefly: ignore[bad-argument-type] | ||
| (x, w, bias), | ||
| ) | ||
|
|
||
|
|
||
| # %% | ||
| # Main | ||
| # ---- | ||
|
|
||
|
|
||
| # %% | ||
| def main() -> None: | ||
| check(8192, 8192, 8192) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need an example file here, shouldn't the autotuner find it automatically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The examples are just to illustrate to users that epilogue subtiling is supported and demonstrate cases where it can provide benefits.