forked from microsoft/onnxscript
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_fuse_conv_affine.py
More file actions
112 lines (97 loc) · 3.77 KB
/
_fuse_conv_affine.py
File metadata and controls
112 lines (97 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Absorbs affine operation into convolution (best effort):
- Conv(Add(Mul(x))) -> Conv (only conv without padding can be fused)
- Add(Mul(Conv)) -> Conv (for all convolutions)
"""
from __future__ import annotations
import numpy as np
import onnx_ir as ir
from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value
class _ConvAffineFusionBase(pattern.RewriteRuleClassBase):
    """Shared match-time preconditions for the conv/affine fusion rules.

    Both fusions fold the affine constants into the Conv weight and bias,
    which is only possible when those values are known at rewrite time.
    """

    def check(
        self,
        context,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> MatchResult:
        """Fail the match unless w/b are constants and scale/offset are constant scalars."""
        result = MatchResult()
        # (value, extractor, failure message) triples, checked in order; the
        # first value whose extractor yields None aborts the match.
        requirements = (
            (w, get_const_value, "The weight of Conv should be constant"),
            (b, get_const_value, "The bias of Conv should be constant"),
            (scale, get_singleton_value, "Operand for Mul should be constant scalar value"),
            (offset, get_singleton_value, "Operand for Add should be constant scalar value"),
        )
        for value, extract, message in requirements:
            if extract(value) is None:
                return result.fail(message)
        return result
class AffineConvFusion(_ConvAffineFusionBase):
    """Pattern: scalar Mul + scalar Add + unpadded Conv --> Conv.

    Conv(x * scale + offset, w, b) == Conv(x, w * scale, b + offset * sum(w)):
    scaling the input is equivalent to scaling the weight, and adding a uniform
    offset to every input element contributes offset * sum(w) to each output
    channel.  The identity only holds when the convolution uses no padding —
    padded zeros would not receive the offset — hence the pads constraint in
    the pattern.
    """

    def pattern(
        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
    ) -> ir.Value:
        return op.Conv(
            x * scale + offset,
            w,
            b,
            pads=[0, 0, 0, 0],
            _allow_other_attributes=True,
            _outputs=["conv_out"],
        )

    def rewrite(
        self,
        op: ir.tape.Tape,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> ir.Value:
        """Replace the matched subgraph with a single Conv over folded constants."""
        scale_value = scale.const_value.numpy()
        offset_value = offset.const_value.numpy()
        w_value = w.const_value.numpy()
        b_value = b.const_value.numpy()
        scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
        # Sum over every non-output-channel axis (weight layout: O, I, *spatial)
        # instead of hard-coding axis=(1, 2, 3), so non-2-D weights are handled
        # correctly as well as the usual 4-D case.
        reduce_axes = tuple(range(1, w_value.ndim))
        offset_bias = ir.tensor(
            b_value + np.sum(w_value * offset_value, axis=reduce_axes, keepdims=False)
        )
        offset_bias = op.initializer(offset_bias, b.name + "_offset")
        # Carry over all attributes (strides, dilations, group, ...) from the
        # original Conv node.
        conv_attributes = conv_out.producer().attributes
        return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes)
class ConvAffineFusion(_ConvAffineFusionBase):
    """Pattern: Conv + scalar Mul + scalar Add --> Conv.

    Conv(x, w, b) * scale + offset == Conv(x, w * scale, b * scale + offset):
    the affine transform distributes over the convolution's linear part, so it
    can be folded into the weight and bias.  Unlike AffineConvFusion this
    holds for any padding, so the pattern places no constraint on pads.
    """

    def pattern(
        self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value
    ) -> ir.Value:
        return (
            op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale
            + offset
        )

    def rewrite(
        self,
        op: ir.tape.Tape,
        x: ir.Value,
        w: ir.Value,
        b: ir.Value,
        scale: ir.Value,
        offset: ir.Value,
        conv_out: ir.Value,
    ) -> ir.Value:
        """Replace the matched subgraph with a single Conv over folded constants."""
        scale_value = scale.const_value.numpy()
        offset_value = offset.const_value.numpy()
        w_value = w.const_value.numpy()
        b_value = b.const_value.numpy()
        # Named to match AffineConvFusion.rewrite for consistency.
        scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled")
        offset_bias = ir.tensor(b_value * scale_value + offset_value)
        offset_bias = op.initializer(offset_bias, b.name + "_offset")
        # Carry over all attributes (strides, pads, dilations, group, ...)
        # from the original Conv node.
        conv_attributes = conv_out.producer().attributes
        return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes)
# Rewrite-rule instances exported for registration in rewriter pipelines.
affine_conv_fusion_rule = AffineConvFusion().rule()
conv_affine_fusion_rule = ConvAffineFusion().rule()