@@ -145,13 +145,16 @@ LogicalResult PtrState::rebuildAsUnsupportedOp(Value operand) {
145145 // Setup state for unsupported operation.
146146 auto indexTy = IndexType::get (operand.getContext ());
147147 auto index0 = IntegerAttr::get (indexTy, APInt (64 , 0 ));
148+ auto index1 = IntegerAttr::get (indexTy, APInt (64 , 1 ));
148149 for (auto size : opShape) {
149- if (size == 1 )
150+ if (size == 1 ) {
150151 offsets.push_back (index0);
151- else
152+ strides.push_back (index0);
153+ } else {
152154 offsets.push_back (operand);
155+ strides.push_back (index1);
156+ }
153157 sizes.push_back (IntegerAttr::get (indexTy, APInt (64 , size)));
154- strides.push_back (index0);
155158 shape.push_back (index0);
156159 }
157160 return success ();
@@ -174,9 +177,10 @@ LogicalResult PtrState::rebuildAsGatherScatter(Value op, int nonContinuousDim) {
174177 // Setup state for nonContinuousDim.
175178 auto indexTy = IndexType::get (op.getContext ());
176179 auto index0 = IntegerAttr::get (indexTy, APInt (64 , 0 ));
180+ auto index1 = IntegerAttr::get (indexTy, APInt (64 , 1 ));
177181
178182 offsets[nonContinuousDim] = op;
179- strides[nonContinuousDim] = index0 ;
183+ strides[nonContinuousDim] = index1 ;
180184 shape[nonContinuousDim] = index0;
181185 return success ();
182186}
@@ -222,43 +226,105 @@ LogicalResult PtrState::addState(const PtrState &lhsState,
222226 addOFRs (lhsState.strides [i], rhsState.strides [i], loc, builder);
223227 strides.push_back (newStride);
224228 } else {
225- // Set stride to 1 when not continuous.
226- strides.push_back (builder.getIndexAttr (1 ));
227- // New offset is offset * stride.
228- auto newLhsOffset = lhsState.offsets [i];
229- auto newRhsOffset = rhsState.offsets [i];
230229 if (isAnalysisingUnstructured) {
231230 assert (!lhsState.hasModulo () && !rhsState.hasModulo () &&
232231 " should not have dimension with modulo when analysing "
233232 " unstructured" );
234- // When the dimension is structured, mul the offset by the stride to
235- // match the stride 1 for non-structured dimensions.
236- // If the dimension is not structured, the offset is already multiplied
237- // by the stride.
238- // If stride is 0 which will happen after
239- // visitOperandExpandDims/visitOperandSplat, we cannot mul which will
240- // get zero and lost the offset.
241- if (lhsState.dimIsStructured (i) && !hasConstZero (lhsState.strides [i])) {
242- auto stride = expandOFRIndex (lhsState.strides [i], lhsState.offsets [i],
243- loc, builder);
244- newLhsOffset = mulOFRs (lhsState.offsets [i], stride, loc, builder);
245- }
246- if (rhsState.dimIsStructured (i) && !hasConstZero (rhsState.strides [i])) {
247- auto stride = expandOFRIndex (rhsState.strides [i], rhsState.offsets [i],
248- loc, builder);
249- newRhsOffset = mulOFRs (rhsState.offsets [i], stride, loc, builder);
250- }
251- // Make sure newLhsOffset and newRhsOffset get same type.
252- if (!lhsState.dimIsStructured (i)) {
253- newRhsOffset =
254- expandOFRIndex (newRhsOffset, newLhsOffset, loc, builder);
233+ if (hasConstZero (lhsState.strides [i]) &&
234+ hasConstZero (lhsState.offsets [i])) {
235+ // If lhs is not for dim i, we can just use rhs's stride and offset.
236+ offsets.push_back (rhsState.offsets [i]);
237+ strides.push_back (rhsState.strides [i]);
238+ } else if (hasConstZero (rhsState.strides [i]) &&
239+ hasConstZero (rhsState.offsets [i])) {
240+ // If rhs is not for dim i, we can just use lhs's stride and offset.
241+ offsets.push_back (lhsState.offsets [i]);
242+ strides.push_back (lhsState.strides [i]);
255243 } else {
256- newLhsOffset =
257- expandOFRIndex (newLhsOffset, newRhsOffset, loc, builder);
244+ OpFoldResult lhsOffset = lhsState.offsets [i];
245+ OpFoldResult rhsOffset = rhsState.offsets [i];
246+ OpFoldResult lhsStride = lhsState.strides [i];
247+ OpFoldResult rhsStride = rhsState.strides [i];
248+ // If stride is 0 which will happen after
249+ // visitOperandExpandDims/visitOperandSplat, we set the stride to 1 to
250+ // mul it with offset.
251+ if (hasConstZero (lhsStride)) {
252+ assert (lhsState.dimIsStructured (i) &&
253+ !rhsState.dimIsStructured (i) &&
254+ " If lhs stride is zero, it must be structured and rhs "
255+ " stride is unstructured" );
256+ lhsStride = builder.getIndexAttr (1 );
257+ }
258+ if (hasConstZero (rhsStride)) {
259+ assert (rhsState.dimIsStructured (i) &&
260+ !lhsState.dimIsStructured (i) &&
261+ " If rhs stride is zero, it must be structured and lhs "
262+ " stride is unstructured" );
263+ rhsStride = builder.getIndexAttr (1 );
264+ }
265+
266+ // If both offset and stride not equal, we merge 2 PtrStates by change
267+ // offset * stride into (offset * stride) * 1 where new offset is
268+ // offset * stride and new stride is set to 1.
269+ // Then we'll have strides equal as 1, and merge them as PtrState with
270+ // same strides.
271+ if (lhsOffset != rhsOffset && lhsStride != rhsStride) {
272+ // Expand offset since unstructured offset has tensor type.
273+ OpFoldResult stride =
274+ expandOFRIndex (lhsStride, lhsOffset, loc, builder);
275+ // new offset = offset * stride
276+ lhsOffset = mulOFRs (lhsOffset, stride, loc, builder);
277+ // Expand offset since unstructured offset has tensor type.
278+ stride = expandOFRIndex (rhsStride, rhsOffset, loc, builder);
279+ // new offset = offset * stride
280+ rhsOffset = mulOFRs (rhsOffset, stride, loc, builder);
281+ // Set both strides to 1.
282+ lhsStride = builder.getIndexAttr (1 );
283+ rhsStride = builder.getIndexAttr (1 );
284+ }
285+
286+ if (lhsStride == rhsStride) {
287+ // For case like lhs_offset * stride + rhs_offset * stride, it is same as
288+ // (lhs_offset + rhs_offset) * stride.
289+ // We can just
290+ // add the offsets and reuse the stride like this:
291+ // offsets[i] = lhsOffset + rhsOffset
292+ // strides[i] = lhsStride
293+ // Expand structured offset since unstructured offset has tensor type.
294+ if (!lhsState.dimIsStructured (i)) {
295+ rhsOffset = expandOFRIndex (rhsOffset, lhsOffset, loc, builder);
296+ } else {
297+ lhsOffset = expandOFRIndex (lhsOffset, rhsOffset, loc, builder);
298+ }
299+ // Add offsets.
300+ offsets.push_back (addOFRs (lhsOffset, rhsOffset, loc, builder));
301+ // Reuse stride.
302+ strides.push_back (lhsStride);
303+ } else {
304+ // Assert that offsets are equal if strides are not equal.
305+ // This is because we are already forcing the strides to be
306+ // equal to 1 earlier for case both offsets and strides not equal.
307+ assert (lhsOffset == rhsOffset &&
308+ " If strides are not equal, offsets must be equal" );
309+ // For case like offset * lhs_stride + offset * rhs_stride, it is same as
310+ // offset * (lhs_stride + rhs_stride).
311+ // We can just
312+ // add the strides and reuse the offset like this:
313+ // offsets[i] = lhsOffset
314+ // strides[i] = lhsStride + rhsStride
315+
316+ // Reuse offsets.
317+ offsets.push_back (lhsOffset);
318+ // Add strides.
319+ strides.push_back (addOFRs (lhsStride, rhsStride, loc, builder));
320+ }
258321 }
259- auto newOffset = addOFRs (newLhsOffset, newRhsOffset, loc, builder);
260- offsets.push_back (newOffset);
261322 } else {
323+ // Set stride to 1 when not continuous.
324+ strides.push_back (builder.getIndexAttr (1 ));
325+ // New offset is offset * stride.
326+ auto newLhsOffset = lhsState.offsets [i];
327+ auto newRhsOffset = rhsState.offsets [i];
262328 // Just propagate the unstructured offset to the result to track the
263329 // unstructured dimension. The real address calculation will be done
264330 // later in the PtrAnalysis::visitOperandAddptr.
@@ -432,13 +498,12 @@ LogicalResult PtrState::mulState(const PtrState &lhsState,
432498 assert (!lhs->dimHasModulo (i) &&
433499 " should not have non-structured dimension with modulo" );
434500 if (isAnalysisingUnstructured) {
435- auto rhsStride =
436- expandOFRIndex (rhs->scalar , lhs->offsets [i], loc, builder);
437501 assert (!lhs->hasModulo () &&
438502 " should not have non-structured dimension with modulo" );
439- OpFoldResult newOffset =
440- mulOFRs (lhs->offsets [i], rhsStride, loc, builder);
441- offsets.push_back (newOffset);
503+ // Keep offsets as is for unstructured dimension.
504+ // The address calculation will be done later in structured to
505+ // memref pass.
506+ offsets.push_back (lhs->offsets [i]);
442507 // Mul the scalar to stride.
443508 OpFoldResult newStride =
444509 mulOFRs (lhs->strides [i], rhs->scalar , loc, builder);
0 commit comments