Skip to content

Commit 76eb4e0

Browse files
author
Xiang Li
committed
Update per comment.
1 parent b06f5f3 commit 76eb4e0

7 files changed

Lines changed: 22 additions & 13 deletions

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,23 @@ static Value applyUnstructuredMask(Operation *op, Value ptr,
5353
}
5454

5555
auto [dim, unstructuredMask] = masks[0];
56-
if (auto scatterPtr =
56+
if (auto gatherScatterPtr =
5757
ptr.getDefiningOp<tts::MakeGatherScatterTensorPtrOp>()) {
58-
if (dim != scatterPtr.getGatherScatterDim()) {
58+
if (dim != gatherScatterPtr.getGatherScatterDim()) {
5959
op->emitRemark("MaskAnalysis failed for unstructured mask dim not equal "
6060
"gather scatter dim");
6161
return nullptr;
6262
}
6363

64-
ptr = builder
65-
.create<tts::MakeGatherScatterTensorPtrOp>(
66-
loc, scatterPtr.getBase(),
67-
scatterPtr.getGatherScatterOffset(), unstructuredMask,
68-
scatterPtr.getGatherScatterDim(), scatterPtr.getSizes(),
69-
scatterPtr.getMixedStrides(), scatterPtr.getMixedOffsets())
70-
.getResult();
71-
64+
ptr =
65+
builder
66+
.create<tts::MakeGatherScatterTensorPtrOp>(
67+
loc, gatherScatterPtr.getBase(),
68+
gatherScatterPtr.getGatherScatterOffset(), unstructuredMask,
69+
gatherScatterPtr.getGatherScatterDim(),
70+
gatherScatterPtr.getSizes(), gatherScatterPtr.getMixedStrides(),
71+
gatherScatterPtr.getMixedOffsets())
72+
.getResult();
7273
} else if (auto structuredPtr = ptr.getDefiningOp<tts::MakeTensorPtrOp>()) {
7374
auto ofrToI32Value = [&](OpFoldResult ofr) {
7475
Value v = dyn_cast<Value>(ofr);

test/Conversion/TritonToStructured/unstructured_mask_2d_kernel.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --cse --canonicalize %s | FileCheck %s
22

3-
// Make sure make_gather_scatter_tptr with unsturctured mask generate correctly from structured ptr with unstructured mask.
3+
// Make sure make_gather_scatter_tptr with unstructuredmask generate correctly from structured ptr with unstructured mask.
4+
// The load is structured ptr, with unstructured mask on dim 0.
5+
// The store is structured ptr, with unstructured mask on dim 1.
46

57
// CHECK-LABEL: tt.func public @generic_mask_2d_kernel(
68
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32},

test/Conversion/TritonToStructured/unstructured_mask_2d_non_continuous_load_kernel.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --cse --canonicalize %s | FileCheck %s
22

3-
// Make sure make_gather_scatter_tptr with unsturctured mask generate correctly from row-structured ptr with unstructured mask.
3+
// Make sure make_gather_scatter_tptr with unstructuredmask generate correctly from row-structured ptr with unstructured mask.
4+
// The load is unstructured ptr on dim 0 and unstructured mask on dim 0.
45

56
// CHECK-LABEL: tt.func public @generic_mask_2d_non_continuous_load_kernel(
67
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32},

test/Conversion/TritonToStructured/unstructured_mask_2d_non_continuous_store_kernel.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --cse --canonicalize %s | FileCheck %s
22

3-
// Make sure make_gather_scatter_tptr with unsturctured mask generate correctly from column-structured ptr with unstructured mask.
3+
// Make sure make_gather_scatter_tptr with unstructuredmask generate correctly from column-structured ptr with unstructured mask.
4+
// The store is unstructured ptr on dim 1 and unstructured mask on dim 1.
45

56
// CHECK-LABEL: tt.func public @generic_mask_2d_non_continuous_store_kernel(
67
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32},

test/Conversion/TritonToStructured/unstructured_mask_3d_kernel.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --cse --canonicalize %s | FileCheck %s
22

33
// Make sure make_gather_scatter_tptr with generic mask generate correctly from structured ptr with unstructured mask.
4+
// The load is structured ptr, with unstructured mask on dim 1.
5+
// The store is structured ptr, with unstructured mask on dim 2.
46

57
// CHECK-LABEL: tt.func public @generic_mask_3d_kernel(
68
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32},

test/Conversion/TritonToStructured/unstructured_mask_3d_non_continuous_load_kernel.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --cse --canonicalize %s | FileCheck %s
22

33
// Make sure make_gather_scatter_tptr with generic mask generate correctly from row-structured ptr with unstructured mask.
4+
// The load is unstructured ptr on dim 1 and unstructured mask on dim 1.
45

56
// CHECK-LABEL: tt.func public @generic_mask_3d_non_continuous_load_kernel(
67
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32},

test/Conversion/TritonToStructured/unstructured_mask_3d_non_continuous_store_kernel.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --cse --canonicalize %s | FileCheck %s
22

33
// Make sure make_gather_scatter_tptr with generic mask generate correctly from column-structured ptr with unstructured mask.
4+
// The store is unstructured ptr on dim 2 and unstructured mask on dim 2.
45

56

67
// CHECK-LABEL: tt.func public @generic_mask_3d_non_continuous_store_kernel(

0 commit comments

Comments
 (0)