@@ -112,11 +112,15 @@ struct PtrState {
112112
113113 // Process addition of two PtrStates.
114114 LogicalResult addState (const PtrState &lhsState, const PtrState &rhsState,
115- Operation *op, OpBuilder &builder);
115+ bool isAnalysisingUnstructured, Operation *op,
116+ OpBuilder &builder);
116117
117118 // Process multiplication of two PtrStates
118119 LogicalResult mulState (const PtrState &lhsState, const PtrState &rhsState,
119- Operation *op, OpBuilder &builder);
120+ bool isAnalysisingUnstructured, Operation *op,
121+ OpBuilder &builder);
122+
123+ LogicalResult mergeUnstructuredState (const PtrState &other, Operation *op);
120124
121125 tts::MakeTensorPtrOp createTTSMakeTensorPtrOp (OpBuilder &builder,
122126 Location loc);
@@ -147,6 +151,41 @@ class PtrAnalysis {
147151
148152 DenseSet<Value> maybeStructuredArgs;
149153 const bool enableMakeGatherScatterTensorPtr;
154+ // If false, PtrAnalysis will analysis structured ptr while only identify
155+ // unstructured ptr.
156+ // If true, PtrAnalysis will caclulate strides and offsets for
157+ // unstructured pointers. This is used to support gather/scatter access.
158+ // The default mode is false. Only set to true when caclulating
159+ // unstructured pointers for gather/scatter access.
160+ // The reason to have different mode is to support case like:
161+ //
162+ // ptr + (row_offsets[:,None] % mod_offset + some_number) +
163+ // row_indices[:None]
164+ //
165+ // (row_offsets[:,None] % mod_offset + some_number) is structured and
166+ // has modulo.
167+ // row_indices[:, None] is unstructured.
168+ // When visiting the add operation, we need to apply the modulo to
169+ // (row_offsets[:,None] % mod_offset + some_number).
170+ // But we don't have the information about how to apply the modulo.
171+ //
172+ // To simplify the analysis, we do the work in two modes:
173+ // 1. Analyze to identify the unstructured pointers.
174+ // 2. Analyze to calculate the strides and offsets for unstructured pointers.
175+ // In mode 1, isAnalysisingUnstructured is set to false, so we only
176+ // identify the unstructured pointers and do not calculate the strides and
177+ // offsets.
178+ // When visit the operand again to calculate the offsets and strides for the
179+ // unstructured state, we'll set isAnalysisingUnstructured to true.
180+ // This means that we switched to mode 2 now and are analyzing the
181+ // unstructured pointers and calculating the strides and offsets for them. In
182+ // mode 2, we know that the pointer is unstructured, so we can just use the
183+ // value of arith::RemSIOp as offset directly. Once the analysis is done,
184+ // we'll switch back to mode 1.
185+ //
186+ // Note that this is might be a temporary solution, and we may need to
187+ // revisit this in the future to support more complex cases.
188+ bool isAnalysisingUnstructured = false ;
150189
151190public:
152191 PtrAnalysis (bool enableMakeGatherScatterTensorPtr)
0 commit comments