• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
17 #define MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
18 #include <map>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "mindapi/base/types.h"
24 #include "ops/base_operator.h"
25 #include "plugin/device/cpu/kernel/nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace ops {
/// \brief Operator name under which EncoderLayer is constructed (passed to BaseOperator).
constexpr const char *kNameEncoderLayer = "EncoderLayer";
30 /// \brief EncoderLayer op in MindIR.
31 class MIND_API EncoderLayer : public BaseOperator {
32  public:
33   MIND_API_BASE_MEMBER(EncoderLayer);
34   /// \brief Constructor.
EncoderLayer()35   EncoderLayer() : BaseOperator(kNameEncoderLayer) {
36     InitIOName({"input", "gamma1", "beta1", "weight_attn_qkv", "bias_attn_qkv", "mask", "weight_attn_o", "bias_attn_o",
37                 "gamma2", "beta2", "weight_m", "bias_m", "weight_p", "bias_p"},
38                {"output"});
39   }
40   /// \brief Initialize EncoderLayer op.
41   /// \param[in] head_num Define head number.
42   /// \param[in] head_size Define size per head.
43   /// \param[in] eps_layernorm1 Define eps layernorm1.
44   /// \param[in] eps_layernorm2 Define eps layernorm2.
45   /// \param[in] eps_layernorm3 Define eps layernorm3.
46   /// \param[in] ffn_hidden_size Define ffn hidden size.
47   /// \param[in] expert_num Define expert num.
48   /// \param[in] expert_offset_id Define expert_offset_id.
49   /// \param[in] capacity_factor Define capacity_factor.
50   /// \param[in] position_bias Define position_bias.
51   /// \param[in] scale Define scale.
52   /// \param[in] act_type Define act_type.
53   /// \param[in] layer_norm Define act_type.
54   /// \param[in] use_past Define use_past.
55   /// \param[in] query_layer Define query_layer.
56   /// \param[in] moe Define moe.
57   /// \param[in] embedding_layer Define embedding_layer.
58 
59   void Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2, float eps_layernorm3,
60             int64_t ffn_hidden_size, int64_t expert_num, int64_t expert_offset_id, float capacity_factor,
61             bool position_bias, bool post_layernorm, float scale = 1.0f, ActType act_type = ActType::ActType_Gelu,
62             bool layer_norm = false, bool use_past = false, bool query_layer = false, bool moe = false,
63             bool embedding_layer = false);
64   void set_head_num(int64_t head_num);
65   void set_head_size(int64_t head_size);
66   void set_post_layernorm(bool post_layernorm);
67   void set_eps_layernorm1(float eps_layernorm1);
68   void set_eps_layernorm2(float eps_layernorm2);
69   void set_eps_layernorm3(float eps_layernorm3);
70   void set_ffn_hidden_size(int64_t ffn_hidden_size);
71   void set_expert_num(int64_t expert_num);
72   void set_expert_offset_id(int64_t expert_offset_id);
73   void set_capacity_factor(float capacity_factor);
74   void set_position_bias(bool position_bias);
75   void set_scale(float scale);
76   void set_act_type(ActType act_type);
77   void set_layer_norm(bool layer_norm);
78   void set_use_past(bool use_past);
79   void set_query_layer(bool query_layer);
80   void set_moe(bool moe);
81   void set_embedding_layer(bool embedding_layer);
82 
83   int64_t get_head_num() const;
84   int64_t get_head_size() const;
85   bool get_post_layernorm() const;
86   float get_eps_layernorm1() const;
87   float get_eps_layernorm2() const;
88   float get_eps_layernorm3() const;
89   int64_t get_ffn_hidden_size() const;
90   int64_t get_expert_num() const;
91   int64_t get_expert_offset_id() const;
92   float get_capacity_factor() const;
93   bool get_position_bias() const;
94   float get_scale() const;
95   ActType get_act_type() const;
96   bool get_layer_norm() const;
97   bool get_use_past() const;
98   bool get_query_layer() const;
99   bool get_moe() const;
100   bool get_embedding_layer() const;
101 };
102 }  // namespace ops
103 }  // namespace mindspore
104 #endif  // MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
105