Skip to content

Commit 1714bad

Browse files
authored
Fix tosa.cast float-to-int rounding to use truncation (RTZ) (#2340)
Upstream tosa-to-linalg lowers tosa.cast float-to-int by inserting math.roundeven before arith.fptosi to implement TOSA's round-to-nearest-even semantics. ONNX and PyTorch instead define float-to-int cast as truncation (round-towards-zero), which is what arith.fptosi already does natively.

Since rocMLIR primarily serves ONNX/MIGraphX workloads, restore RTZ semantics without modifying upstream LLVM:

- The migraphx.convert lowering in MIGraphXToTosa tags float-to-int tosa.cast ops with the FusedLoc metadata "rocmlir.rtz_cast". Other casts (int-to-float, int-to-int, float-to-float, unsigned) are left untouched to avoid affecting unrelated lowerings or stripping legitimate rounding upstream may insert in the future.
- New conversion pass fix-tosa-cast-rounding strips math.roundeven inside linalg.generic when (a) it (or its parent generic) carries the RTZ tag and (b) it exclusively feeds the recognized cast chain (clamp / saturation merge ending at arith.fptosi). Quantization casts (which need RNE) are untouched because they are not tagged.
- The pass is wired into the bufferize pipeline immediately after tosa-to-linalg.

Tests:

- Lit tests cover both the float-clamp and i32-saturation matching paths, plus negatives for untagged roundeven, roundeven outside linalg.generic, quantization, float-to-float convert, and roundeven with extra users.
- A CANARY RUN line guards against silent upstream regressions in tosa-to-linalg's choice to emit math.roundeven.
- New CPU e2e test verifies actual numerical RTZ behaviour (3.5 -> 3, -3.5 -> -3, etc.).
1 parent 2f4f074 commit 1714bad

13 files changed

Lines changed: 670 additions & 2 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- FixTosaCastRounding.h - Fix tosa.cast rounding -----------*- C++ -*-===//
2+
//
3+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// Copyright (c) 2026 Advanced Micro Devices Inc.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#ifndef MLIR_CONVERSION_FIXTOSACASTROUNDING_FIXTOSACASTROUNDING_H
12+
#define MLIR_CONVERSION_FIXTOSACASTROUNDING_FIXTOSACASTROUNDING_H
13+
14+
#include "mlir/Pass/Pass.h"
15+
#include "llvm/ADT/StringRef.h"
16+
17+
namespace mlir {
18+
19+
#define GEN_PASS_DECL_FIXTOSACASTROUNDINGPASS
20+
#include "mlir/Conversion/RocMLIRPasses.h.inc"
21+
22+
namespace rock {
23+
/// FusedLoc metadata tag used to mark tosa.cast ops that want RTZ rounding.
24+
/// Casts from migraphx.convert carry this tag; casts from quantization do not.
25+
/// Read by the fix-tosa-cast-rounding pass to decide whether to strip the
26+
/// math.roundeven that upstream tosa-to-linalg inserts before arith.fptosi.
27+
constexpr llvm::StringLiteral kRtzCastLocTag("rocmlir.rtz_cast");
28+
} // namespace rock
29+
30+
} // namespace mlir
31+
32+
#endif // MLIR_CONVERSION_FIXTOSACASTROUNDING_FIXTOSACASTROUNDING_H

mlir/include/mlir/Conversion/RocMLIRPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_CONVERSION_ROCMLIRPASSES_H
1111

1212
#include "mlir/Conversion/EmulateFp8ExtTrunc/EmulateFp8ExtTrunc.h"
13+
#include "mlir/Conversion/FixTosaCastRounding/FixTosaCastRounding.h"
1314
#include "mlir/Conversion/LinalgToRock/LinalgToRock.h"
1415
#include "mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h"
1516
#include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h"

mlir/include/mlir/Conversion/RocMLIRPasses.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,40 @@ def ConvertRockToGPUPass : Pass<"convert-rock-to-gpu", "ModuleOp"> {
3333
];
3434
}
3535

36+
//===----------------------------------------------------------------------===//
37+
// FixTosaCastRoundingPass
38+
//===----------------------------------------------------------------------===//
39+
40+
def FixTosaCastRoundingPass
41+
: Pass<"fix-tosa-cast-rounding", "::mlir::func::FuncOp"> {
42+
let summary = "Change tosa.cast float-to-int from round-to-nearest-even to "
43+
"round-towards-zero";
44+
let description = [{
45+
The upstream tosa-to-linalg pass inserts math.roundeven before arith.fptosi
46+
when lowering tosa.cast, implementing TOSA's round-to-nearest-even
47+
semantics. This pass removes those math.roundeven ops so that
48+
arith.fptosi's native truncation (round towards zero) is used instead,
49+
matching ONNX and PyTorch cast semantics. Note that this intentionally
50+
diverges from the TOSA spec; rocMLIR primarily serves ONNX/MIGraphX
51+
workloads where RTZ is the expected behavior.
52+
53+
Only math.roundeven ops that belong to an RTZ-tagged cast lowering are
54+
removed; this preserves round-to-nearest-even for quantization casts.
55+
The tag is the FusedLoc metadata exposed as `mlir::rock::kRtzCastLocTag`
56+
and is set by the migraphx.convert lowering in migraphx-to-tosa. Because
57+
upstream tosa-to-linalg often rewrites the inner roundeven's location,
58+
the tag may end up on the math.roundeven itself, on the parent
59+
linalg.generic's location, or on one of the generic region's *output*
60+
block-argument locations; the pass checks all three. Input block-arg
61+
locations are deliberately ignored to avoid false positives when a
62+
downstream generic merely consumes a previously-tagged cast result.
63+
To stay safe the pass also bails on multi-output generics and on i1
64+
outputs.
65+
}];
66+
let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect",
67+
"math::MathDialect"];
68+
}
69+
3670
//===----------------------------------------------------------------------===//
3771
// RocmlirCustomTosaDecomposePass
3872
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_subdirectory(EmulateFp8ExtTrunc)
2+
add_subdirectory(FixTosaCastRounding)
23
add_subdirectory(MIGraphXToTosa)
34
add_subdirectory(RocmlirCustomTosaDecompose)
45
add_subdirectory(RocmlirCustomTosaToLinalg)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_rocmlir_conversion_library(RocmlirFixTosaCastRounding
2+
FixTosaCastRounding.cpp
3+
4+
DEPENDS
5+
RocMLIRConversionPassIncGen
6+
)
7+
8+
target_link_libraries(RocmlirFixTosaCastRounding
9+
PUBLIC
10+
MLIRIR
11+
MLIRPass
12+
MLIRSupport
13+
MLIRArithDialect
14+
MLIRFuncDialect
15+
MLIRLinalgDialect
16+
MLIRMathDialect
17+
MLIRTransformUtils
18+
)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
//===- FixTosaCastRounding.cpp - Fix tosa.cast rounding -------------------===//
2+
//
3+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// Copyright (c) 2026 Advanced Micro Devices Inc.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
//
11+
// The upstream tosa-to-linalg pass inserts math.roundeven before arith.fptosi
12+
// when lowering tosa.cast from float to integer. This implements TOSA's
13+
// "round to nearest, ties to even" semantics.
14+
//
15+
// However, ONNX and PyTorch define float-to-int cast as truncation (round
16+
// towards zero), which is what arith.fptosi already does natively. Since
17+
// rocMLIR primarily serves ONNX/MIGraphX workloads, this pass removes the
18+
// math.roundeven ops to restore RTZ semantics without modifying the upstream
19+
// LLVM code.
20+
//
21+
//===----------------------------------------------------------------------===//
22+
23+
#include "mlir/Conversion/FixTosaCastRounding/FixTosaCastRounding.h"
24+
#include "mlir/Dialect/Arith/IR/Arith.h"
25+
#include "mlir/Dialect/Func/IR/FuncOps.h"
26+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
27+
#include "mlir/Dialect/Math/IR/Math.h"
28+
#include "mlir/IR/BuiltinTypes.h"
29+
#include "mlir/IR/Location.h"
30+
#include "mlir/IR/PatternMatch.h"
31+
#include "mlir/IR/TypeUtilities.h"
32+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33+
#include "llvm/ADT/STLExtras.h"
34+
35+
namespace mlir {
36+
#define GEN_PASS_DEF_FIXTOSACASTROUNDINGPASS
37+
#include "mlir/Conversion/RocMLIRPasses.h.inc"
38+
} // namespace mlir
39+
40+
using namespace mlir;
41+
42+
namespace {
43+
44+
/// Returns true when `roundeven`'s result participates *exclusively* in the
45+
/// upstream tosa-to-linalg float-to-int cast chain. The chain has two parts:
46+
/// 1. Float clamp: optional `arith.minimumf`/`maximumf` on the rounded
47+
/// value, ending at `arith.fptosi`.
48+
/// 2. Integer saturation merge (i32 case): the rounded value also feeds
49+
/// `arith.cmpf` to produce an i1 mask, which feeds an `arith.select`
50+
/// that picks between an integer saturation constant and the
51+
/// `arith.fptosi` result. The merged i32 then flows to `linalg.yield`.
52+
///
53+
/// We follow both branches and accept `linalg.yield` as a terminal. Removing
54+
/// the `math.roundeven` is safe even for the saturation comparison: at the
55+
/// i32 saturation boundary (|2^31|) every f32 value is already an integer
56+
/// (f32 ULP is >= 1 above 2^23), so `roundeven` is a no-op there and the
57+
/// `arith.cmpf` result is unchanged.
58+
///
59+
/// This is intentionally strict: if any user of a value in the chain is not
60+
/// recognized, we return false (do not strip the `roundeven`) even when a
61+
/// sibling user does reach an `arith.fptosi`. This prevents miscompiles
62+
/// where the rounded value is also consumed by an unrelated op that depends
63+
/// on RNE semantics.
64+
static bool reachesFPToSI(math::RoundEvenOp op) {
65+
SmallVector<Value, 8> worklist(op->getResults());
66+
bool foundFPToSI = false;
67+
while (!worklist.empty()) {
68+
Value v = worklist.pop_back_val();
69+
for (Operation *user : v.getUsers()) {
70+
if (isa<arith::FPToSIOp>(user)) {
71+
foundFPToSI = true;
72+
continue;
73+
}
74+
if (isa<linalg::YieldOp>(user))
75+
continue;
76+
if (isa<arith::MinimumFOp, arith::MaximumFOp, arith::CmpFOp,
77+
arith::SelectOp>(user)) {
78+
for (Value r : user->getResults())
79+
worklist.push_back(r);
80+
continue;
81+
}
82+
return false;
83+
}
84+
}
85+
return foundFPToSI;
86+
}
87+
88+
/// Returns true when `loc` is a `FusedLoc` whose metadata is a `StringAttr`
/// equal to `rock::kRtzCastLocTag`. Only the location itself is inspected;
/// locations nested inside it are not searched.
static bool hasRtzCastLocTag(Location loc) {
  auto fusedLoc = dyn_cast<FusedLoc>(loc);
  if (!fusedLoc)
    return false;

  // The metadata may be absent or a non-string attribute; both mean "no tag".
  auto tag = dyn_cast_or_null<StringAttr>(fusedLoc.getMetadata());
  return tag && tag.getValue() == rock::kRtzCastLocTag;
}
94+
95+
/// True when this `math.roundeven` is part of an RTZ-tagged
96+
/// `migraphx.convert` lowering. The tag is set on the `tosa.cast` and ends
97+
/// up on the parent `linalg.generic`'s loc and on its output region block
98+
/// argument (the one carved out from the `tensor.empty()` that this cast
99+
/// writes into). Upstream `tosa-to-linalg` may assign the inner
100+
/// `math.roundeven` a different `Location`, so we don't rely on the op's
101+
/// own loc alone.
102+
///
103+
/// We deliberately do NOT scan input block arguments: those inherit the
104+
/// loc of their incoming SSA value, and if that value comes from a
105+
/// previously-tagged cast the tag would propagate forward, causing this
106+
/// pass to wrongly strip an unrelated `math.roundeven` in a downstream
107+
/// generic.
108+
static bool isRtzTaggedCastLowering(math::RoundEvenOp op,
109+
linalg::GenericOp generic) {
110+
if (hasRtzCastLocTag(op->getLoc()) || hasRtzCastLocTag(generic.getLoc()))
111+
return true;
112+
return llvm::any_of(generic.getRegionOutputArgs(), [](BlockArgument arg) {
113+
return hasRtzCastLocTag(arg.getLoc());
114+
});
115+
}
116+
117+
struct RemoveRoundEvenBeforeFPToSI
118+
: public OpRewritePattern<math::RoundEvenOp> {
119+
using OpRewritePattern::OpRewritePattern;
120+
121+
LogicalResult matchAndRewrite(math::RoundEvenOp op,
122+
PatternRewriter &rewriter) const override {
123+
auto generic = op->getParentOfType<linalg::GenericOp>();
124+
if (!generic)
125+
return failure();
126+
127+
// The RTZ-tagged cast lowering corresponds to a single tosa.cast and
128+
// therefore produces exactly one integer output from the generic. Use
129+
// getOutputs() rather than getResultTypes() so this still works in
130+
// buffer semantics (where there are no SSA results).
131+
//
132+
// Bail on multi-output generics (e.g. produced by linalg fusion) to
133+
// avoid stripping a math.roundeven that also feeds a sibling result --
134+
// for instance an i1 yielded directly from arith.cmpf, which would
135+
// silently flip if we removed the rounding.
136+
//
137+
// Bail on i1 outputs as well: ONNX/PyTorch float-to-bool semantics is
138+
// "non-zero" rather than truncation, so removing roundeven would be
139+
// unsafe even if upstream tosa-to-linalg ever emitted it for an i1
140+
// cast. Today MIGraphXToTosa does not tag float-to-i1 casts, but this
141+
// guard is defense-in-depth.
142+
ValueRange outs = generic.getOutputs();
143+
if (outs.size() != 1)
144+
return failure();
145+
Type outElemTy = getElementTypeOrSelf(outs[0].getType());
146+
if (!isa<IntegerType>(outElemTy) || outElemTy.isInteger(1))
147+
return failure();
148+
149+
if (!isRtzTaggedCastLowering(op, generic))
150+
return failure();
151+
152+
if (!reachesFPToSI(op))
153+
return failure();
154+
155+
rewriter.replaceOp(op, op.getOperand());
156+
return success();
157+
}
158+
};
159+
160+
struct FixTosaCastRoundingPass
161+
: public impl::FixTosaCastRoundingPassBase<FixTosaCastRoundingPass> {
162+
using FixTosaCastRoundingPassBase::FixTosaCastRoundingPassBase;
163+
164+
void runOnOperation() override {
165+
RewritePatternSet patterns(&getContext());
166+
patterns.add<RemoveRoundEvenBeforeFPToSI>(&getContext());
167+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
168+
return signalPassFailure();
169+
}
170+
};
171+
172+
} // namespace

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313
#include "mlir/Conversion/MIGraphXToTosa/MIGraphXToTosa.h"
14+
#include "mlir/Conversion/FixTosaCastRounding/FixTosaCastRounding.h"
1415
#include "mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1313,9 +1314,28 @@ ConvertConverter::matchAndRewrite(migraphx::ConvertOp op, OpAdaptor adaptor,
13131314
ROCK_CUSTOMOP_UNSIGNED_CAST, ROCK_CUSTOMOP_DOMAIN_NAME, "",
13141315
adaptor.getInA());
13151316
} else {
1316-
rewriter.replaceOpWithNewOp<tosa::CastOp>(
1317-
op, getTypeConverter()->convertType(op.getResult().getType()),
1317+
// Tag float-to-int casts with RTZ metadata so that fix-tosa-cast-rounding
1318+
// can distinguish them (want truncation) from quantization casts (want
1319+
// RNE). Other casts (int-to-float, int-to-int, float-to-float) don't go
1320+
// through math.roundeven today; tagging them serves no purpose and would
1321+
// risk stripping legitimate rounding if upstream tosa-to-linalg ever
1322+
// inserts it (e.g. for narrowing float-to-float casts).
1323+
//
1324+
// Float-to-bool is excluded explicitly: ONNX/PyTorch bool cast semantics
1325+
// is "non-zero" (not truncation), and upstream tosa-to-linalg lowers it
1326+
// via arith.cmpf une rather than roundeven+fptosi. Tagging it would be
1327+
// misleading and unsafe if upstream ever changes that lowering.
1328+
Location castLoc = op.getLoc();
1329+
if (isa<FloatType>(inputType) && isa<IntegerType>(outputType) &&
1330+
!outputType.isInteger(1))
1331+
castLoc =
1332+
FusedLoc::get(op.getContext(), {op.getLoc()},
1333+
StringAttr::get(op.getContext(), rock::kRtzCastLocTag));
1334+
auto castOp = tosa::CastOp::create(
1335+
rewriter, castLoc,
1336+
getTypeConverter()->convertType(op.getResult().getType()),
13181337
adaptor.getInA());
1338+
rewriter.replaceOp(op, castOp);
13191339
}
13201340
return success();
13211341
}

mlir/lib/Dialect/Rock/Pipelines/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_rocmlir_dialect_library(MLIRRockPipeline
2020
MLIRRockTransforms
2121
MLIRRockUtility
2222
MLIRUBToLLVM
23+
RocmlirFixTosaCastRounding
2324
RocmlirCustomTosaDecompose
2425
RocmlirCustomTosaToLinalg
2526
RocmlirEmulateFp8ExtTrunc

mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ void rock::buildBufferizePipeline(OpPassManager &pm,
9797
/*validationOptions=*/std::nullopt,
9898
/*attachTargetOptions=*/tosaOptions);
9999

100+
// Strip math.roundeven inserted by tosa-to-linalg for RTZ-tagged casts.
101+
auto &castFixPm = pm.nest<func::FuncOp>();
102+
castFixPm.addPass(createFixTosaCastRoundingPass());
103+
100104
// convert named linalg operations into linalg generic
101105
LinalgMorphOpsPassOptions morphOptions;
102106
morphOptions.namedToCategory = false;

0 commit comments

Comments
 (0)