1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_OPS_ENCODER_LAYER_H_ 17 #define MINDSPORE_CORE_OPS_ENCODER_LAYER_H_ 18 #include <map> 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "mindapi/base/types.h" 24 #include "ops/base_operator.h" 25 #include "plugin/device/cpu/kernel/nnacl/op_base.h" 26 27 namespace mindspore { 28 namespace ops { 29 constexpr auto kNameEncoderLayer = "EncoderLayer"; 30 /// \brief EncoderLayer op in MindIR. 31 class MIND_API EncoderLayer : public BaseOperator { 32 public: 33 MIND_API_BASE_MEMBER(EncoderLayer); 34 /// \brief Constructor. EncoderLayer()35 EncoderLayer() : BaseOperator(kNameEncoderLayer) { 36 InitIOName({"input", "gamma1", "beta1", "weight_attn_qkv", "bias_attn_qkv", "mask", "weight_attn_o", "bias_attn_o", 37 "gamma2", "beta2", "weight_m", "bias_m", "weight_p", "bias_p"}, 38 {"output"}); 39 } 40 /// \brief Initialize EncoderLayer op. 41 /// \param[in] head_num Define head number. 42 /// \param[in] head_size Define size per head. 43 /// \param[in] eps_layernorm1 Define eps layernorm1. 44 /// \param[in] eps_layernorm2 Define eps layernorm2. 45 /// \param[in] eps_layernorm3 Define eps layernorm3. 46 /// \param[in] ffn_hidden_size Define ffn hidden size. 47 /// \param[in] expert_num Define expert num. 48 /// \param[in] expert_offset_id Define expert_offset_id. 49 /// \param[in] capacity_factor Define capacity_factor. 50 /// \param[in] position_bias Define position_bias. 51 /// \param[in] scale Define scale. 52 /// \param[in] act_type Define act_type. 53 /// \param[in] layer_norm Define act_type. 54 /// \param[in] use_past Define use_past. 55 /// \param[in] query_layer Define query_layer. 56 /// \param[in] moe Define moe. 57 /// \param[in] embedding_layer Define embedding_layer. 58 59 void Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2, float eps_layernorm3, 60 int64_t ffn_hidden_size, int64_t expert_num, int64_t expert_offset_id, float capacity_factor, 61 bool position_bias, bool post_layernorm, float scale = 1.0f, ActType act_type = ActType::ActType_Gelu, 62 bool layer_norm = false, bool use_past = false, bool query_layer = false, bool moe = false, 63 bool embedding_layer = false); 64 void set_head_num(int64_t head_num); 65 void set_head_size(int64_t head_size); 66 void set_post_layernorm(bool post_layernorm); 67 void set_eps_layernorm1(float eps_layernorm1); 68 void set_eps_layernorm2(float eps_layernorm2); 69 void set_eps_layernorm3(float eps_layernorm3); 70 void set_ffn_hidden_size(int64_t ffn_hidden_size); 71 void set_expert_num(int64_t expert_num); 72 void set_expert_offset_id(int64_t expert_offset_id); 73 void set_capacity_factor(float capacity_factor); 74 void set_position_bias(bool position_bias); 75 void set_scale(float scale); 76 void set_act_type(ActType act_type); 77 void set_layer_norm(bool layer_norm); 78 void set_use_past(bool use_past); 79 void set_query_layer(bool query_layer); 80 void set_moe(bool moe); 81 void set_embedding_layer(bool embedding_layer); 82 83 int64_t get_head_num() const; 84 int64_t get_head_size() const; 85 bool get_post_layernorm() const; 86 float get_eps_layernorm1() const; 87 float get_eps_layernorm2() const; 88 float get_eps_layernorm3() const; 89 int64_t get_ffn_hidden_size() const; 90 int64_t get_expert_num() const; 91 int64_t get_expert_offset_id() const; 92 float get_capacity_factor() const; 93 bool get_position_bias() const; 94 float get_scale() const; 95 ActType get_act_type() const; 96 bool get_layer_norm() const; 97 bool get_use_past() const; 98 bool get_query_layer() const; 99 bool get_moe() const; 100 bool get_embedding_layer() const; 101 }; 102 } // namespace ops 103 } // namespace mindspore 104 #endif // MINDSPORE_CORE_OPS_ENCODER_LAYER_H_ 105