@@ -8,10 +8,17 @@ namespace passes {
88
99void ReduceGelu (std::shared_ptr<torch::jit::Graph>& graph) {
1010 std::string gelu_pattern = R"IR(
11- graph(%x):
11+ graph(%x : Tensor ):
1212 %out : Tensor = aten::gelu(%x)
1313 return (%out))IR" ;
1414
15+ // This gelu_approximate_pattern schema exists in 21.11, 21.12, 22.01 containers of pytorch. These container versions use
16+ // an unmerged PR in pytorch : https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
17+ std::string gelu_approximate_pattern = R"IR(
18+ graph(%x : Tensor, %approx):
19+ %out : Tensor = aten::gelu(%x, %approx)
20+ return (%out))IR" ;
21+
1522 std::string gelu_reduce_pattern = R"IR(
1623 graph(%x.1 : Tensor):
1724 %6 : float = prim::Constant[value=0.044714999999999998]()
@@ -30,11 +37,36 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
3037 %15 : Tensor = aten::mul(%7, %14)
3138 return (%15))IR" ;
3239
40+ // This is same as gelu_reduce_pattern except for an additional input %approx.
41+ // SubgraphRewriter only works as expected if the number of inputs to gelu_approximate_pattern
42+ // and gelu_reduce_multi_input_pattern are same.
43+ std::string gelu_reduce_multi_input_pattern = R"IR(
44+ graph(%x.1 : Tensor, %approx):
45+ %6 : float = prim::Constant[value=0.044714999999999998]()
46+ %5 : float = prim::Constant[value=0.79788456080000003]()
47+ %4 : float = prim::Constant[value=1.]()
48+ %3 : float = prim::Constant[value=0.5]()
49+ %2 : int = prim::Constant[value=1]()
50+ %7 : Tensor = aten::mul(%x.1, %3)
51+ %8 : Tensor = aten::mul(%x.1, %5)
52+ %9 : Tensor = aten::mul(%x.1, %6)
53+ %10 : Tensor = aten::mul(%9, %x.1)
54+ %11 : Tensor = aten::add(%10, %4, %2)
55+ %12 : Tensor = aten::mul(%8, %11)
56+ %13 : Tensor = aten::tanh(%12)
57+ %14 : Tensor = aten::add(%13, %4, %2)
58+ %15 : Tensor = aten::mul(%7, %14)
59+ return (%15))IR" ;
60+
3361 // replace aten::gelu with pointwise operations
3462 torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
3563 map_gelu_to_pointwise_ops.RegisterRewritePattern (gelu_pattern, gelu_reduce_pattern);
3664 map_gelu_to_pointwise_ops.runOnGraph (graph);
3765
66+ torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
67+ map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern (gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
68+ map_gelu_approximate_to_pointwise_ops.runOnGraph (graph);
69+
3870 LOG_GRAPH (" Post lowering of [aten::gelu] -> " << *graph);
3971}
4072
0 commit comments