/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <vector>
#include "backend/common/graph_kernel/expanders/op_desc_registry.h"
#include "utils/ms_context.h"

namespace mindspore::graphkernel::expanders {
namespace {
constexpr double csv_value = 0.044714998453855515;
// gelu(x) = 0.5 * x * (1 + Erf(x/sqrt(2)))
// gelu(x) = 0.5 * x * (1 + tanh(y)), y = sqrt(2/pi) * (x + 0.044715*x*x*x)
// Since neither AKG nor Ascend has a basic instruction for tanh, tanh itself is composed of basic instructions.
// Therefore we also expand tanh(x):
// tanh(x) = (e^x - e^{-x}) / (e^x + e^{-x}) = 2/(1 + e^{-2x}) - 1
// gelu(x) = 0.5 * x * (1 + 2/(1 + e^{-2y}) - 1) = x / (1 + e^{-2y})
// After this expansion, the number of basic operators drops from 14 to 8 and memory can be reused
// (only the input and output buffers are needed to complete one GELU operation). In AKG, for a 1024x1024
// float16 GELU, the original expansion required 64 cores on 910B, while the new code needs only 32 cores.
// With the basic instructions significantly reduced, kernel performance improves by 18%.
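// Quick numeric check of the approximation (illustrative only, not used by the code): for x = 1,
// y = sqrt(2/pi) * 1.044715 ~ 0.8336 and tanh(y) ~ 0.682, so gelu(1) ~ 0.5 * 1 * 1.682 ~ 0.841,
// which matches the erf form 0.5 * (1 + erf(1/sqrt(2))) ~ 0.841.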
NodePtr GeLUByTanh(const inner::GraphBuilder &gb, const NodePtr &input_x, const TypeId dtype) {
  // np.sqrt(2/np.pi)
  constexpr double csv_value_sqrt_two_div_pi = 0.7978845608028654;

  // compute y = sqrt(2/pi) * (x + 0.044715 * x^3)
  auto mul_0 = gb.Mul(input_x, input_x);
  auto pow_0 = gb.Mul(mul_0, input_x);
  auto const_csvalue = gb.Tensor(csv_value, dtype);
  auto mul_1 = gb.Mul(pow_0, const_csvalue);
  auto tanh_res = gb.Add(input_x, mul_1);
  auto const_csvalue_sqrt_two_div_pi = gb.Tensor(csv_value_sqrt_two_div_pi, dtype);
  auto y = gb.Mul(tanh_res, const_csvalue_sqrt_two_div_pi);

  // compute gelu(x) = 0.5 * x * (1 + tanh(y))
  auto tanh_y = gb.Tanh(y);
  auto const_one = gb.Tensor(1, dtype);
  auto const_half = gb.Tensor(0.5, dtype);
  auto tanh_y_add_one = gb.Add(tanh_y, const_one);
  auto mul_x = gb.Mul(input_x, tanh_y_add_one);
  auto result = gb.Mul(mul_x, const_half);

  return result;
}
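// Scalar reference of the graph built above (illustrative sketch only, not part of the expander):
//   float gelu_tanh_ref(float x) {
//     float y = 0.7978845608f * (x + 0.044715f * x * x * x);
//     return 0.5f * x * (1.0f + std::tanh(y));
//   }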

NodePtr GeLUAscend(const inner::GraphBuilder &gb, const NodePtr &input, const TypeId dtype) {
  // -np.sqrt(8/np.pi) = -2 * sqrt(2/pi): the factor -2 of e^{-2y} is folded into this constant.
  constexpr double csv_value_sqrt_eight_div_pi = -1.5957691216057308;
  // compute in float32 and cast back to the original dtype at the end
  auto input_x = input;
  if (dtype != kNumberTypeFloat32) {
    input_x = gb.Cast(input_x, kNumberTypeFloat32);
  }
  // the local y holds -2 * sqrt(2/pi) * (x + 0.044715 * x^3), i.e. the exponent of e^{-2y} above
  auto mul_0 = gb.Mul(input_x, input_x);
  auto pow_0 = gb.Mul(mul_0, input_x);
  auto const_csvalue = gb.Tensor(csv_value, kNumberTypeFloat32);
  auto mul_1 = gb.Mul(pow_0, const_csvalue);
  auto tanh_res = gb.Add(input_x, mul_1);
  auto const_csvalue_sqrt_eight_div_pi = gb.Tensor(csv_value_sqrt_eight_div_pi, kNumberTypeFloat32);
  auto y = gb.Mul(tanh_res, const_csvalue_sqrt_eight_div_pi);

  // gelu(x) = x / (1 + e^{-2y})
  auto exp_0 = gb.Exp(y);
  auto const_one = gb.Tensor(1, kNumberTypeFloat32);
  auto add_0 = gb.Add(exp_0, const_one);
  auto result = gb.Div(input_x, add_0);
  if (dtype != kNumberTypeFloat32) {
    result = gb.Cast(result, dtype);
  }
  return result;
}
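// Scalar reference of the exp-based path above (illustrative sketch only, not part of the expander):
//   float gelu_exp_ref(float x) {
//     float t = -1.5957691216f * (x + 0.044715f * x * x * x);  // t = -2 * y
//     return x / (1.0f + std::exp(t));
//   }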
}  // namespace
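// OpDesc-based expander for the GeLU op: on the Ascend AI Core ("aicore") processor it uses the
// exp-based expansion above, otherwise it falls back to the tanh-based form.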
class GeLU : public OpDesc {
 public:
  GeLU() = default;
  ~GeLU() = default;

 protected:
  NodePtrList Expand(const NodePtrList &inputs) override {
    if (processor_ == "aicore") {
      return {GeLUAscend(gb, inputs[0], inputs[0]->type)};
    } else {
      return {GeLUByTanh(gb, inputs[0], inputs[0]->type)};
    }
  }
};
EXPANDER_OP_DESC_REGISTER("GeLU", GeLU);

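// Free helper that expands GeLU with the exp-based formulation; presumably shared with other
// expanders that need to inline a GeLU subgraph.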
NodePtr GeluExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs) {
  return GeLUAscend(gb, inputs[0], inputs[0]->type);
}
}  // namespace mindspore::graphkernel::expanders