• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_OPS_DECODER_LAYER_H_
17 #define MINDSPORE_CORE_OPS_DECODER_LAYER_H_
18 #include <map>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "mindapi/base/types.h"
24 #include "ops/base_operator.h"
25 #include "plugin/device/cpu/kernel/nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace ops {
29 constexpr auto kNameDecoderLayer = "DecoderLayer";
30 /// \brief DecoderLayer op in MindIR.
31 class MIND_API DecoderLayer : public BaseOperator {
32  public:
33   MIND_API_BASE_MEMBER(DecoderLayer);
34   /// \brief Constructor.
DecoderLayer()35   DecoderLayer() : BaseOperator(kNameDecoderLayer) {
36     InitIOName({"input",
37                 "gamma1",
38                 "beta1",
39                 "weight_qkv",
40                 "bias_attn_qkv",
41                 "input_mask",
42                 "weight_attn_o",
43                 "bias_attn_o",
44                 "gamma2",
45                 "beta2",
46                 "encoder_output",
47                 "weight_attn_q",
48                 "weight_attn_kv",
49                 "bias_attn_cross_qkv",
50                 "cross_mask",
51                 "weight_attn_cross_o",
52                 "bias_attn_cross_o",
53                 "gamma3",
54                 "beta3",
55                 "weight_m",
56                 "bias_m",
57                 "weight_p",
58                 "bias_p"},
59                {"output"});
60   }
61   /// \brief Initialize DecoderLayer op.
62   /// \param[in] head_num Define head number.
63   /// \param[in] head_size Define size per head.
64   /// \param[in] eps_layernorm1 Define eps layernorm1.
65   /// \param[in] eps_layernorm2 Define eps layernorm2.
66   /// \param[in] eps_layernorm3 Define eps layernorm3.
67   /// \param[in] eps_layernorm4 Define eps layernorm4.
68   /// \param[in] ffn_hidden_size Define ffn hidden size.
69   /// \param[in] position_bias1 Define position_bias1.
70   /// \param[in] position_bias2 Define position_bias2.
71   /// \param[in] scale1 Define scale1.
72   /// \param[in] scale2 Define scale2.
73   /// \param[in] act_type Define act_type.
74   /// \param[in] layer_norm Define act_type.
75   void Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2, float eps_layernorm3,
76             float eps_layernorm4, int64_t ffn_hidden_size, bool position_bias1, bool position_bias2,
77             bool post_layernorm, float scale1 = 1.0f, float scale2 = 1.0f, ActType act_type = ActType::ActType_Gelu,
78             bool layer_norm = false);
79   void set_head_num(int64_t head_num);
80   void set_head_size(int64_t head_size);
81   void set_post_layernorm(bool post_layernorm);
82   void set_eps_layernorm1(float eps_layernorm1);
83   void set_eps_layernorm2(float eps_layernorm2);
84   void set_eps_layernorm3(float eps_layernorm3);
85   void set_eps_layernorm4(float eps_layernorm4);
86   void set_ffn_hidden_size(int64_t ffn_hidden_size);
87   void set_position_bias1(bool position_bias1);
88   void set_position_bias2(bool position_bias2);
89   void set_scale1(float scale1);
90   void set_scale2(float scale2);
91   void set_act_type(ActType act_type);
92   void set_layer_norm(bool layer_norm);
93   int64_t get_head_num() const;
94   int64_t get_head_size() const;
95   bool get_post_layernorm() const;
96   float get_eps_layernorm1() const;
97   float get_eps_layernorm2() const;
98   float get_eps_layernorm3() const;
99   float get_eps_layernorm4() const;
100   int64_t get_ffn_hidden_size() const;
101   bool get_position_bias1() const;
102   bool get_position_bias2() const;
103   float get_scale1() const;
104   float get_scale2() const;
105   ActType get_act_type() const;
106   bool get_layer_norm() const;
107 };
108 }  // namespace ops
109 }  // namespace mindspore
110 #endif  // MINDSPORE_CORE_OPS_DECODER_LAYER_H_
111