@@ -360,7 +360,7 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
360360 loc, rewriter);
361361 auto zeroMap = AffineMap::getConstantMap (0 , rewriter.getContext ());
362362 auto loadOp = rewriter.create <affine::AffineLoadOp>(
363- op.getLoc (), sMemRef , zeroMap, std:: nullopt );
363+ op.getLoc (), sMemRef , zeroMap, ValueRange{} );
364364 rewriter.replaceOp (op, loadOp.getResult ());
365365 return success ();
366366 }
@@ -520,7 +520,7 @@ struct StoreConverter : public OpConversionPattern<triton::StoreOp> {
520520 PtrAnalysis::getScalarMemRef (op.getPtr (), ptr, loc, rewriter);
521521 auto zeroMap = AffineMap::getConstantMap (0 , rewriter.getContext ());
522522 rewriter.create <affine::AffineStoreOp>(loc, val, sMemRef , zeroMap,
523- std:: nullopt );
523+ ValueRange{} );
524524 rewriter.eraseOp (op);
525525 return success ();
526526 }
@@ -649,6 +649,28 @@ struct SplatConverter : public OpConversionPattern<triton::SplatOp> {
649649 }
650650};
651651
652+ struct UnsplatConverter : public OpConversionPattern <triton::UnsplatOp> {
653+ using OpConversionPattern::OpConversionPattern;
654+
655+ LogicalResult
656+ matchAndRewrite (triton::UnsplatOp op, OpAdaptor adaptor,
657+ ConversionPatternRewriter &rewriter) const override {
658+ auto tensorType = op.getSrc ().getType ();
659+
660+ // Only generate indices for non-zero rank tensors.
661+ SmallVector<Value, 1 > indices (tensorType.getRank ());
662+ if (indices.size () > 0 ) {
663+ auto zeroIdx =
664+ rewriter.createOrFold <arith::ConstantIndexOp>(op.getLoc (), 0 );
665+ llvm::fill (indices, zeroIdx);
666+ }
667+
668+ rewriter.replaceOpWithNewOp <tensor::ExtractOp>(op, adaptor.getSrc (),
669+ indices);
670+ return success ();
671+ }
672+ };
673+
652674struct BroadcastConverter : public OpConversionPattern <triton::BroadcastOp> {
653675private:
654676 using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
@@ -1397,24 +1419,6 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
13971419 return success ();
13981420 }
13991421
1400- LogicalResult
1401- convertToTensorExtract (triton::ReduceOp op,
1402- typename triton::ReduceOp::Adaptor adaptor,
1403- ConversionPatternRewriter &rewriter) const {
1404- assert (llvm::hasSingleElement (op.getSrcs ()));
1405-
1406- auto returnOp = cast<triton::ReduceReturnOp>(*op.getOps ().begin ());
1407- assert (llvm::hasSingleElement (returnOp.getResult ()));
1408- assert (cast<BlockArgument>(returnOp.getResult ().front ()).getArgNumber () ==
1409- 0 );
1410-
1411- auto source = op.getSrcs ().front ();
1412- auto zeroIdx =
1413- rewriter.createOrFold <arith::ConstantIndexOp>(op.getLoc (), 0 );
1414- rewriter.replaceOpWithNewOp <tensor::ExtractOp>(op, source, zeroIdx);
1415- return success ();
1416- }
1417-
14181422public:
14191423 LogicalResult
14201424 matchAndRewrite (triton::ReduceOp op,
@@ -1431,14 +1435,6 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
14311435 " axis is within "
14321436 " operand's rank" );
14331437
1434- // Unsplat is implemented as a single element, rank 1 reduction where
1435- // single element is yielded immediately. This can be simplified into
1436- // a single element extract.
1437- if (llvm::hasSingleElement (op.getOps ()) && sourceType.getRank () == 1 &&
1438- sourceType.getShape ()[0 ] == 1 ) {
1439- return convertToTensorExtract (op, adaptor, rewriter);
1440- }
1441-
14421438 return convertToLinalgReduce (op, adaptor, rewriter);
14431439 }
14441440};
0 commit comments