Skip to content

Commit 166ff39

Browse files
authored
add common subexpression elimination (#44386)
1 parent 64b61fc commit 166ff39

5 files changed

Lines changed: 523 additions & 0 deletions

File tree

paddle/fluid/framework/ir/CMakeLists.txt

100755100644
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
115115
pass_library(dense_fc_to_sparse_pass inference)
116116
pass_library(dense_multihead_matmul_to_sparse_pass inference)
117117
pass_library(generate_pass DEPS pass_desc_proto)
118+
pass_library(common_subexpression_elimination_pass inference)
118119
target_link_libraries(generate_pass pass_desc_proto)
119120

120121
if(WITH_TENSORRT)
@@ -326,6 +327,10 @@ cc_test(
326327
test_generate_pass_cc
327328
SRCS generate_pass_tester.cc
328329
DEPS generate_pass pass_desc_proto)
330+
cc_test(
331+
test_common_subexpression_elimination_pass_cc
332+
SRCS common_subexpression_elimination_pass_tester.cc
333+
DEPS common_subexpression_elimination_pass)
329334
cc_test(
330335
test_delete_dropout_pass_cc
331336
SRCS delete_dropout_op_pass_test.cc
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/common_subexpression_elimination_pass.h"
16+
#include <string>
17+
#include <type_traits>
18+
19+
#include "paddle/fluid/framework/framework.pb.h"
20+
#include "paddle/fluid/framework/ir/graph_helper.h"
21+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
22+
#include "paddle/fluid/framework/ir/node.h"
23+
#include "paddle/fluid/framework/op_version_registry.h"
24+
#include "paddle/fluid/framework/type_defs.h"
25+
#include "paddle/phi/core/enforce.h"
26+
#include "paddle/utils/variant.h"
27+
28+
namespace {
29+
30+
std::string NodeTypeToString(paddle::framework::ir::Node::Type type) {
31+
if (type == paddle::framework::ir::Node::Type::kOperation) {
32+
return "kOperation";
33+
} else {
34+
return "kVariable";
35+
}
36+
}
37+
38+
const std::unordered_set<std::string> commutative_operators{"mul",
39+
"bitwise_and",
40+
"bitwise_or",
41+
"equal_all",
42+
"equal",
43+
"not_equal",
44+
"logical_and",
45+
"logical_or",
46+
"elementwise_max",
47+
"elementwise_fmax",
48+
"elementwise_min",
49+
"elementwise_fmin",
50+
"elementwise_mul",
51+
"elementwise_add",
52+
"add_p",
53+
"max_p",
54+
"mul_p",
55+
"eq_p",
56+
"ne_p"};
57+
58+
const std::unordered_set<std::string> nondeterministic_operators{
59+
"dropout",
60+
"dropout_nd",
61+
"gaussian_random_batch_size_like",
62+
"gaussian_random",
63+
"randint",
64+
"random_crop",
65+
"random_routing",
66+
"randperm",
67+
"uniform_random_batch_size_like",
68+
"uniform_random_inplace",
69+
"uniform_random",
70+
"fused_bias_dropout_residual_layer_norm"};
71+
72+
const std::unordered_set<std::string> side_effect_operators{
73+
"feed", "cast", "fetch", "fill_constant", "fill_constant_batch_size_like"};
74+
75+
template <class T>
76+
inline void HashCombine(std::size_t *seed, const T &v) {
77+
std::hash<T> hasher;
78+
(*seed) ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2);
79+
}
80+
81+
} // namespace
82+
83+
namespace std {
84+
85+
#define HASH_ATTRIBUTE(attr, id, type) \
86+
do { \
87+
if (attr.index() == id) { \
88+
return std::hash<type>{}(get<id>(attr)); \
89+
} \
90+
} while (0)
91+
92+
#define HASH_VECTOR_ATTRIBUTE(attr, id, type) \
93+
do { \
94+
if (attr.index() == id) { \
95+
std::vector<type> vec = get<id>(attr); \
96+
size_t seed = 0; \
97+
for (const auto &v : vec) { \
98+
HashCombine(&seed, v); \
99+
} \
100+
return seed; \
101+
} \
102+
} while (0)
103+
104+
template <>
105+
struct hash<paddle::framework::proto::VarType_Type> {
106+
size_t operator()(const paddle::framework::proto::VarType_Type &attr) const {
107+
using type = typename std::underlying_type<
108+
paddle::framework::proto::VarType_Type>::type;
109+
return std::hash<type>()(static_cast<type>(attr));
110+
}
111+
};
112+
113+
template <>
114+
struct hash<paddle::framework::Attribute> {
115+
size_t operator()(const paddle::framework::Attribute &attr) const {
116+
if (attr.index() == 0) {
117+
return 0;
118+
}
119+
if (attr.index() == 7) {
120+
return static_cast<size_t>(get<7>(attr));
121+
}
122+
123+
HASH_ATTRIBUTE(attr, 1, int);
124+
HASH_ATTRIBUTE(attr, 2, float);
125+
HASH_ATTRIBUTE(attr, 3, std::string);
126+
HASH_VECTOR_ATTRIBUTE(attr, 4, int);
127+
HASH_VECTOR_ATTRIBUTE(attr, 5, float);
128+
HASH_VECTOR_ATTRIBUTE(attr, 6, std::string);
129+
HASH_ATTRIBUTE(attr, 8, std::vector<bool>);
130+
HASH_ATTRIBUTE(attr, 9, paddle::framework::BlockDesc *);
131+
HASH_ATTRIBUTE(attr, 10, int64_t);
132+
HASH_VECTOR_ATTRIBUTE(attr, 11, paddle::framework::BlockDesc *);
133+
HASH_VECTOR_ATTRIBUTE(attr, 12, int64_t);
134+
HASH_VECTOR_ATTRIBUTE(attr, 13, double);
135+
return 0;
136+
}
137+
};
138+
} // namespace std
139+
140+
namespace paddle {
141+
namespace framework {
142+
namespace ir {
143+
144+
void CommonSubexpressionEliminationPass::ApplyImpl(ir::Graph *graph) const {
145+
PADDLE_ENFORCE_EQ(
146+
graph->IsMainGraph(),
147+
true,
148+
platform::errors::InvalidArgument(
149+
"CommonSubexpressionEliminationPass only accepts main graph"));
150+
151+
CommonSubexpressionEliminate(
152+
graph, graph, [](Node *) -> Node * { return nullptr; });
153+
}
154+
155+
void CommonSubexpressionEliminationPass::CommonSubexpressionEliminate(
156+
ir::Graph *main_graph,
157+
ir::Graph *graph,
158+
std::function<Node *(Node *)> parent_exist_nodes) const {
159+
const char *kSubBlock = "sub_block";
160+
std::unordered_set<ir::Node *, HashOpNode, EqualOpNode> exist_nodes;
161+
std::vector<Node *> nodes = TopologySortOperations(*graph);
162+
for (Node *node : nodes) {
163+
if (node->inputs.empty()) {
164+
continue;
165+
}
166+
if (side_effect_operators.count(node->Name()) != 0) {
167+
continue;
168+
}
169+
if (nondeterministic_operators.count(node->Name()) != 0) {
170+
continue;
171+
}
172+
173+
if (node->Op()->HasAttr(kSubBlock)) {
174+
auto sub_block_id =
175+
node->Op()->GetAttrIfExists<BlockDesc *>(kSubBlock)->ID();
176+
CommonSubexpressionEliminate(
177+
main_graph,
178+
main_graph->GetSubGraph(sub_block_id),
179+
[&exist_nodes, &parent_exist_nodes](Node *node) -> Node * {
180+
auto exist_node = exist_nodes.find(node);
181+
if (exist_node != exist_nodes.end()) {
182+
return *exist_node;
183+
}
184+
return parent_exist_nodes(node);
185+
});
186+
continue;
187+
}
188+
189+
Node *exist_node = parent_exist_nodes(node);
190+
if (exist_node == nullptr) {
191+
auto res = exist_nodes.insert(node);
192+
if (!res.second) {
193+
exist_node = *res.first;
194+
}
195+
}
196+
197+
if (exist_node != nullptr) {
198+
for (size_t i = 0; i < exist_node->outputs.size(); ++i) {
199+
Node *exist_node_output = exist_node->outputs[i];
200+
Node *current_node_output = node->outputs[i];
201+
std::vector<Node *> current_node_output_outputs =
202+
current_node_output->outputs;
203+
for (size_t i = 0; i < current_node_output_outputs.size(); ++i) {
204+
IR_NODE_LINK_TO(exist_node_output, current_node_output_outputs[i]);
205+
}
206+
}
207+
GraphSafeRemoveNodes(graph,
208+
std::unordered_set<const Node *>(
209+
node->outputs.begin(), node->outputs.end()));
210+
GraphSafeRemoveNodes(graph, {node});
211+
}
212+
}
213+
}
214+
215+
size_t HashOpNode::operator()(const Node *node) const {
216+
PADDLE_ENFORCE_EQ(node->IsOp(),
217+
true,
218+
platform::errors::InvalidArgument(
219+
"HashOpNode only supports operation node type"));
220+
221+
size_t seed = 0;
222+
std::vector<Node *> inputs(node->inputs);
223+
if (commutative_operators.count(node->Name()) != 0) {
224+
auto comparator = [](Node *a, Node *b) { return a->Name() > b->Name(); };
225+
std::stable_sort(inputs.begin(), inputs.end(), comparator);
226+
}
227+
for (size_t i = 0; i < inputs.size(); ++i) {
228+
HashCombine(&seed, inputs[i]->id());
229+
HashCombine(&seed, node->GraphId());
230+
}
231+
const std::string kDepVarName = std::string(Node::kControlDepVarName);
232+
for (size_t i = 0; i < node->outputs.size(); ++i) {
233+
if (node->outputs[i] == nullptr) {
234+
continue;
235+
}
236+
if (node->outputs[i]->IsCtrlVar()) {
237+
HashCombine(&seed, kDepVarName);
238+
} else if (node->outputs[i]->IsVar()) {
239+
HashCombine(&seed, node->outputs[i]->Var()->GetType());
240+
}
241+
}
242+
OpDesc *desc = node->Op();
243+
std::vector<std::string> attributes = desc->AttrNames();
244+
sort(attributes.begin(), attributes.end());
245+
for (const std::string &attribute : attributes) {
246+
HashCombine(&seed, desc->GetAttr(attribute));
247+
}
248+
return seed;
249+
}
250+
251+
bool EqualOpNode::operator()(const Node *lhs, const Node *rhs) const {
252+
PADDLE_ENFORCE_EQ(lhs->IsOp() && rhs->IsOp(),
253+
true,
254+
platform::errors::InvalidArgument(
255+
"EqualOpNode only supports operation node type"));
256+
257+
if (lhs == nullptr && rhs == nullptr) {
258+
return true;
259+
}
260+
if (lhs == nullptr || rhs == nullptr) {
261+
return false;
262+
}
263+
if (lhs->NodeType() != rhs->NodeType()) {
264+
return false;
265+
}
266+
if (lhs->Name() != rhs->Name()) {
267+
return false;
268+
}
269+
270+
std::vector<Node *> lhs_inputs(lhs->inputs);
271+
std::vector<Node *> rhs_inputs(rhs->inputs);
272+
if (commutative_operators.count(lhs->Name()) != 0) {
273+
auto comparator = [](Node *a, Node *b) { return a->Name() > b->Name(); };
274+
std::stable_sort(lhs_inputs.begin(), lhs_inputs.end(), comparator);
275+
std::stable_sort(rhs_inputs.begin(), rhs_inputs.end(), comparator);
276+
}
277+
278+
// compare inputs value
279+
if (lhs_inputs.size() != rhs_inputs.size()) {
280+
return false;
281+
}
282+
if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) {
283+
return false;
284+
}
285+
286+
// compare attribute
287+
const OpDesc *lhs_desc = lhs->Op();
288+
const OpDesc *rhs_desc = rhs->Op();
289+
std::vector<std::string> lhs_attr_names = lhs_desc->AttrNames();
290+
std::vector<std::string> rhs_attr_names = rhs_desc->AttrNames();
291+
if (lhs_attr_names.size() != rhs_attr_names.size()) {
292+
return false;
293+
}
294+
std::sort(lhs_attr_names.begin(), lhs_attr_names.end());
295+
std::sort(rhs_attr_names.begin(), rhs_attr_names.end());
296+
for (size_t i = 0; i < lhs_attr_names.size(); ++i) {
297+
if (lhs_attr_names[i] != rhs_attr_names[i]) {
298+
return false;
299+
}
300+
if (lhs_desc->GetAttr(lhs_attr_names[i]) !=
301+
rhs_desc->GetAttr(rhs_attr_names[i])) {
302+
return false;
303+
}
304+
}
305+
306+
// compare outputs value type
307+
std::vector<Node *> lhs_outputs(lhs->outputs);
308+
std::vector<Node *> rhs_outputs(rhs->outputs);
309+
if (lhs_outputs.size() != rhs_outputs.size()) {
310+
return false;
311+
}
312+
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
313+
if (!lhs_outputs[i]->IsVar() || !rhs_outputs[i]->IsVar()) {
314+
return false;
315+
}
316+
if (lhs_outputs[i]->IsCtrlVar() != rhs_outputs[i]->IsCtrlVar()) {
317+
return false;
318+
}
319+
if (lhs_outputs[i]->IsCtrlVar() && rhs_outputs[i]->IsCtrlVar()) {
320+
continue;
321+
}
322+
if (lhs_outputs[i]->Var()->GetType() != rhs_outputs[i]->Var()->GetType()) {
323+
return false;
324+
}
325+
}
326+
return true;
327+
}
328+
329+
} // namespace ir
330+
} // namespace framework
331+
} // namespace paddle
332+
333+
REGISTER_PASS(common_subexpression_elimination_pass,
334+
paddle::framework::ir::CommonSubexpressionEliminationPass);
335+
REGISTER_PASS_CAPABILITY(common_subexpression_elimination_pass);

0 commit comments

Comments
 (0)