|
19 | 19 | * Normalization precision must be float or double |
20 | 20 | """ |
21 | 21 |
|
# Element types accepted for the input/scale of the RMS-norm pattern.
float_types = frozenset(
    (
        ir.DataType.FLOAT,
        ir.DataType.FLOAT16,
        ir.DataType.BFLOAT16,
        ir.DataType.DOUBLE,
    )
)
# Precisions permitted for the internal normalization computation (stash type).
fp_float_types = frozenset((ir.DataType.FLOAT, ir.DataType.DOUBLE))
29 | 31 |
|
30 | 32 |
|
class RmsNormFusion(pattern.RewriteRuleClassBase):
    """Fuse an explicit RMS-normalization subgraph into SimplifiedLayerNormalization.

    The pattern matches ``scale * (x * 1/sqrt(mean(x^2, axis=-1) + epsilon))`` with
    optional Casts around the computation (handled via ``pattern.OrValue``), and the
    rewrite emits ORT's SimplifiedLayerNormalization op with the matched epsilon and
    the stash (computation) precision determined during ``check``.
    """

    def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
        # The input may optionally be cast up to the computation precision.
        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
        squared = op.Pow(x, 2.0)
        mean_sq = op.ReduceMean(squared, [-1], keepdims=1, noop_with_empty_axes=0)
        mean_sq_eps = op.Add(mean_sq, epsilon)
        rms = op.Sqrt(mean_sq_eps)
        inv_rms = op.Reciprocal(rms)
        normalized = op.Mul(x, inv_rms)
        # The normalized value may optionally be cast back to the target dtype.
        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
        return op.Mul(scale, normalized)

    def check(
        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
    ) -> pattern.MatchResult:  # type: ignore[name-defined]
        """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
        result = pattern.MatchResult()
        # epsilon must be a scalar
        eps_value = _ir_utils.get_singleton_value(epsilon)
        if not isinstance(eps_value, float):  # TODO: support other types
            return result.fail("Epsilon is not a float value.", epsilon)
        if x.dtype not in float_types:
            return result.fail("Input is not a float type.", x)
        if scale.dtype not in float_types:
            return result.fail("Scale is not a float type.", scale)
        # Stash precision: the Cast target when the optional input Cast matched,
        # otherwise the input's own dtype. Recorded on self for use by rewrite().
        if compute_dtype is None:
            self._stash_dtype = x.dtype
        else:
            self._stash_dtype = compute_dtype.as_int()
        if self._stash_dtype not in fp_float_types:
            return result.fail("Normalization precision is not a float or double type.")
        # target_dtype is guaranteed to be the same as scale type in a well-typed input
        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
        return result

    def rewrite(self, op, x, scale, epsilon, **_):
        # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
        # No need to use com.microsoft domain here; but this is a custom op in ORT.
        return op.SimplifiedLayerNormalization(
            x,
            scale,
            axis=-1,
            epsilon=_ir_utils.get_singleton_value(epsilon),
            stash_type=self._stash_dtype,
        )
84 | 76 |
|
85 | 77 |
|
# A single rule now covers all cast/no-cast variants (the optional Casts are
# expressed with pattern.OrValue inside RmsNormFusion.pattern).
_rule = RmsNormFusion.rule()

rms_normalization_rules = [_rule]
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
93 | 81 |
|
94 | 82 |
|
|
0 commit comments