1
2 /**
3 * Copyright 2021 Huawei Technologies Co., Ltd
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 #include "ops/encoder_layer.h"
19
20 #include "mindapi/base/shared_ptr.h"
21 #include "mindapi/ir/value.h"
22 #include "mindapi/src/helper.h"
23 #include "mindspore/core/ops/nn_ops.h"
24 #include "ops/op_name.h"
25 #include "ops/primitive_c.h"
26 #include "utils/log_adapter.h"
27
28 namespace mindspore::ops {
29 MIND_API_OPERATOR_IMPL(EncoderLayer, BaseOperator);
30
set_head_num(int64_t head_num)31 void EncoderLayer::set_head_num(int64_t head_num) { (void)this->AddAttr(kNumHeads, api::MakeValue(head_num)); }
32
set_head_size(int64_t head_size)33 void EncoderLayer::set_head_size(int64_t head_size) { (void)this->AddAttr(kSizePerHead, api::MakeValue(head_size)); }
34
set_post_layernorm(bool post_layernorm)35 void EncoderLayer::set_post_layernorm(bool post_layernorm) {
36 (void)this->AddAttr(kPostLayernorm, api::MakeValue(post_layernorm));
37 }
38
set_eps_layernorm1(float eps_layernorm1)39 void EncoderLayer::set_eps_layernorm1(float eps_layernorm1) {
40 (void)this->AddAttr(kEpsLayerNorm1, api::MakeValue(eps_layernorm1));
41 }
42
set_eps_layernorm2(float eps_layernorm2)43 void EncoderLayer::set_eps_layernorm2(float eps_layernorm2) {
44 (void)this->AddAttr(kEpsLayerNorm2, api::MakeValue(eps_layernorm2));
45 }
46
set_eps_layernorm3(float eps_layernorm3)47 void EncoderLayer::set_eps_layernorm3(float eps_layernorm3) {
48 (void)this->AddAttr(kEpsLayerNorm3, api::MakeValue(eps_layernorm3));
49 }
set_ffn_hidden_size(int64_t ffn_hidden_size)50 void EncoderLayer::set_ffn_hidden_size(int64_t ffn_hidden_size) {
51 (void)this->AddAttr(kFfnHiddenSize, api::MakeValue(ffn_hidden_size));
52 }
set_expert_num(int64_t expert_num)53 void EncoderLayer::set_expert_num(int64_t expert_num) { (void)this->AddAttr(kExpertNum, api::MakeValue(expert_num)); }
set_expert_offset_id(int64_t expert_offset_id)54 void EncoderLayer::set_expert_offset_id(int64_t expert_offset_id) {
55 (void)this->AddAttr(kExpertOffsetId, api::MakeValue(expert_offset_id));
56 }
set_capacity_factor(float capacity_factor)57 void EncoderLayer::set_capacity_factor(float capacity_factor) {
58 (void)this->AddAttr(kCapacityFactor, api::MakeValue(capacity_factor));
59 }
set_position_bias(bool position_bias)60 void EncoderLayer::set_position_bias(bool position_bias) {
61 (void)this->AddAttr(kPositionBias1, api::MakeValue(position_bias));
62 }
set_scale(float scale)63 void EncoderLayer::set_scale(float scale) { (void)this->AddAttr(kScale, api::MakeValue(scale)); }
set_layer_norm(bool layer_norm)64 void EncoderLayer::set_layer_norm(bool layer_norm) { (void)this->AddAttr(kLayerNorm, api::MakeValue(layer_norm)); }
65
set_act_type(ActType act_type)66 void EncoderLayer::set_act_type(ActType act_type) {
67 (void)this->AddAttr(kActivationType, api::MakeValue(static_cast<int32_t>(act_type)));
68 }
set_use_past(bool use_past)69 void EncoderLayer::set_use_past(bool use_past) { (void)this->AddAttr(kUsePast, api::MakeValue(use_past)); }
set_query_layer(bool query_layer)70 void EncoderLayer::set_query_layer(bool query_layer) { (void)this->AddAttr(kQueryLayer, api::MakeValue(query_layer)); }
set_moe(bool moe)71 void EncoderLayer::set_moe(bool moe) { (void)this->AddAttr(kMoe, api::MakeValue(moe)); }
set_embedding_layer(bool embedding_layer)72 void EncoderLayer::set_embedding_layer(bool embedding_layer) {
73 (void)this->AddAttr(kEmbeddingLayer, api::MakeValue(embedding_layer));
74 }
75
get_head_num() const76 int64_t EncoderLayer::get_head_num() const {
77 auto value_ptr = this->GetAttr(kNumHeads);
78 return GetValue<int64_t>(value_ptr);
79 }
80
get_head_size() const81 int64_t EncoderLayer::get_head_size() const {
82 auto value_ptr = this->GetAttr(kSizePerHead);
83 return GetValue<int64_t>(value_ptr);
84 }
85
get_post_layernorm() const86 bool EncoderLayer::get_post_layernorm() const {
87 auto value_ptr = this->GetAttr(kPostLayernorm);
88 return GetValue<bool>(value_ptr);
89 }
90
get_eps_layernorm1() const91 float EncoderLayer::get_eps_layernorm1() const {
92 auto value_ptr = this->GetAttr(kEpsLayerNorm1);
93 return GetValue<float>(value_ptr);
94 }
95
get_eps_layernorm2() const96 float EncoderLayer::get_eps_layernorm2() const {
97 auto value_ptr = this->GetAttr(kEpsLayerNorm2);
98 return GetValue<float>(value_ptr);
99 }
100
get_eps_layernorm3() const101 float EncoderLayer::get_eps_layernorm3() const {
102 auto value_ptr = this->GetAttr(kEpsLayerNorm3);
103 return GetValue<float>(value_ptr);
104 }
105
get_ffn_hidden_size() const106 int64_t EncoderLayer::get_ffn_hidden_size() const {
107 auto value_ptr = this->GetAttr(kFfnHiddenSize);
108 return GetValue<int64_t>(value_ptr);
109 }
110
get_expert_num() const111 int64_t EncoderLayer::get_expert_num() const {
112 auto value_ptr = this->GetAttr(kExpertNum);
113 return GetValue<int64_t>(value_ptr);
114 }
115
get_expert_offset_id() const116 int64_t EncoderLayer::get_expert_offset_id() const {
117 auto value_ptr = this->GetAttr(kExpertOffsetId);
118 return GetValue<int64_t>(value_ptr);
119 }
120
get_capacity_factor() const121 float EncoderLayer::get_capacity_factor() const {
122 auto value_ptr = this->GetAttr(kCapacityFactor);
123 return GetValue<float>(value_ptr);
124 }
125
get_position_bias() const126 bool EncoderLayer::get_position_bias() const {
127 auto value_ptr = this->GetAttr(kPositionBias1);
128 return GetValue<bool>(value_ptr);
129 }
130
get_scale() const131 float EncoderLayer::get_scale() const {
132 auto value_ptr = this->GetAttr(kScale);
133 return GetValue<float>(value_ptr);
134 }
135
get_act_type() const136 ActType EncoderLayer::get_act_type() const {
137 auto value_ptr = GetAttr(kActivationType);
138 if (value_ptr == nullptr) {
139 return ActType::ActType_No;
140 }
141 return ActType(GetValue<int64_t>(value_ptr));
142 }
143
get_layer_norm() const144 bool EncoderLayer::get_layer_norm() const {
145 auto value_ptr = this->GetAttr(kLayerNorm);
146 return GetValue<bool>(value_ptr);
147 }
148
get_use_past() const149 bool EncoderLayer::get_use_past() const {
150 auto value_ptr = this->GetAttr(kUsePast);
151 return GetValue<bool>(value_ptr);
152 }
153
get_query_layer() const154 bool EncoderLayer::get_query_layer() const {
155 auto value_ptr = this->GetAttr(kQueryLayer);
156 return GetValue<bool>(value_ptr);
157 }
158
get_moe() const159 bool EncoderLayer::get_moe() const {
160 auto value_ptr = this->GetAttr(kMoe);
161 return GetValue<bool>(value_ptr);
162 }
163
get_embedding_layer() const164 bool EncoderLayer::get_embedding_layer() const {
165 auto value_ptr = this->GetAttr(kEmbeddingLayer);
166 if (value_ptr == nullptr) {
167 return false;
168 }
169 return GetValue<bool>(value_ptr);
170 }
Init(int64_t head_num,int64_t head_size,float eps_layernorm1,float eps_layernorm2,float eps_layernorm3,int64_t ffn_hidden_size,int64_t expert_num,int64_t expert_offset_id,float capacity_factor,bool position_bias,bool post_layernorm,float scale,ActType act_type,bool layer_norm,bool use_past,bool query_layer,bool moe,bool embedding_layer)171 void EncoderLayer::Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2,
172 float eps_layernorm3, int64_t ffn_hidden_size, int64_t expert_num, int64_t expert_offset_id,
173 float capacity_factor, bool position_bias, bool post_layernorm, float scale, ActType act_type,
174 bool layer_norm, bool use_past, bool query_layer, bool moe, bool embedding_layer) {
175 this->set_head_num(head_num);
176 this->set_head_size(head_size);
177 this->set_post_layernorm(post_layernorm);
178 this->set_eps_layernorm1(eps_layernorm1);
179 this->set_eps_layernorm2(eps_layernorm2);
180 this->set_eps_layernorm3(eps_layernorm3);
181 this->set_ffn_hidden_size(ffn_hidden_size);
182 this->set_expert_num(expert_num);
183 this->set_expert_offset_id(expert_offset_id);
184 this->set_capacity_factor(capacity_factor);
185 this->set_position_bias(position_bias);
186 this->set_act_type(act_type);
187 this->set_scale(scale);
188 this->set_layer_norm(layer_norm);
189 this->set_use_past(use_past);
190 this->set_query_layer(query_layer);
191 this->set_moe(moe);
192 this->set_embedding_layer(embedding_layer);
193 }
194 REGISTER_PRIMITIVE_C(kNameEncoderLayer, EncoderLayer);
195 } // namespace mindspore::ops
196