1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "node_functions.h"
17
18 #include "node_registry.h"
19 #include <message_parcel.h>
20 #include <v1_0/node_attr_types.h>
21
22 namespace OHOS {
23 namespace HDI {
24 namespace Nnrt {
25 namespace V1_0 {
GetAddPrimitive(const std::vector<int8_t> & primitive)26 PrimUniquePtr GetAddPrimitive(const std::vector<int8_t>& primitive)
27 {
28 AddFusion addAttr;
29 auto ret = ParsePrimitive<AddFusion>(primitive, addAttr, AddFusionBlockUnmarshalling);
30 if (ret != HDF_SUCCESS) {
31 HDF_LOGE("Parse primitive data of AddFusion operator failed.");
32 return nullptr;
33 }
34
35 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
36 prim->value.type = mindspore::schema::PrimitiveType_AddFusion;
37 auto attr = new (std::nothrow) mindspore::schema::AddFusionT;
38 if (attr == nullptr) {
39 HDF_LOGE("Create AddFusion primitive failed.");
40 return nullptr;
41 }
42 attr->activation_type = static_cast<mindspore::schema::ActivationType>(addAttr.activationType);
43 prim->value.value = attr;
44 return prim;
45 }
46
GetAvgPoolPrimitive(const std::vector<int8_t> & primitive)47 PrimUniquePtr GetAvgPoolPrimitive(const std::vector<int8_t>& primitive)
48 {
49 AvgPoolFusion avgPoolAttr;
50 auto ret = ParsePrimitive<AvgPoolFusion>(primitive, avgPoolAttr, AvgPoolFusionBlockUnmarshalling);
51 if (ret != HDF_SUCCESS) {
52 HDF_LOGE("Parse primitive data of AvgPoolFusion operator failed.");
53 return nullptr;
54 }
55
56 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
57 prim->value.type = mindspore::schema::PrimitiveType_AvgPoolFusion;
58
59 auto attr = new (std::nothrow) mindspore::schema::AvgPoolFusionT;
60 if (attr == nullptr) {
61 HDF_LOGE("Create AvgPoolFusion primitive failed.");
62 return nullptr;
63 }
64 attr->kernel_size = avgPoolAttr.kernelSize;
65 attr->strides = avgPoolAttr.strides;
66 attr->pad = avgPoolAttr.pad;
67 attr->pad_mode = static_cast<mindspore::schema::PadMode>(avgPoolAttr.padMode);
68 attr->round_mode = static_cast<mindspore::schema::RoundMode>(avgPoolAttr.roundMode);
69 attr->format = static_cast<mindspore::schema::Format>(avgPoolAttr.format);
70 attr->global = avgPoolAttr.global;
71 attr->activation_type = static_cast<mindspore::schema::ActivationType>(avgPoolAttr.activationType);
72 prim->value.value = attr;
73 return prim;
74 }
75
GetConcatPrimitive(const std::vector<int8_t> & primitive)76 PrimUniquePtr GetConcatPrimitive(const std::vector<int8_t>& primitive)
77 {
78 Concat concatAttr;
79 auto ret = ParsePrimitive<Concat>(primitive, concatAttr, ConcatBlockUnmarshalling);
80 if (ret != HDF_SUCCESS) {
81 HDF_LOGE("Parse primitive data of Concat operator failed.");
82 return nullptr;
83 }
84
85 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
86 prim->value.type = mindspore::schema::PrimitiveType_Concat;
87
88 auto attr = new (std::nothrow) mindspore::schema::ConcatT;
89 if (attr == nullptr) {
90 HDF_LOGE("Create concat primitive failed.");
91 return nullptr;
92 }
93 attr->axis = concatAttr.axis;
94 prim->value.value = attr;
95 return prim;
96 }
97
GetConv2dPrimitive(const std::vector<int8_t> & primitive)98 PrimUniquePtr GetConv2dPrimitive(const std::vector<int8_t>& primitive)
99 {
100 Conv2DFusion conv2dAttr;
101 auto ret = ParsePrimitive<Conv2DFusion>(primitive, conv2dAttr, Conv2DFusionBlockUnmarshalling);
102 if (ret != HDF_SUCCESS) {
103 HDF_LOGE("Parse primitive data of Conv2DFusion operator failed.");
104 return nullptr;
105 }
106
107 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
108 prim->value.type = mindspore::schema::PrimitiveType_Conv2DFusion;
109
110 auto attr = new (std::nothrow) mindspore::schema::Conv2DFusionT;
111 if (attr == nullptr) {
112 HDF_LOGE("Create Conv2DFusion primitive failed.");
113 return nullptr;
114 }
115
116 attr->kernel_size = conv2dAttr.kernelSize;
117 attr->stride = conv2dAttr.stride;
118 attr->dilation = conv2dAttr.dilation;
119 attr->pad_mode = static_cast<mindspore::schema::PadMode>(conv2dAttr.padMode);
120 attr->pad_list = conv2dAttr.padList;
121 attr->group = conv2dAttr.group;
122 attr->in_channel = conv2dAttr.inChannel;
123 attr->out_channel = conv2dAttr.outChannel;
124 attr->activation_type = static_cast<mindspore::schema::ActivationType>(conv2dAttr.activationType);
125
126 prim->value.value = attr;
127 return prim;
128 }
129
GetFullConnectionPrimitive(const std::vector<int8_t> & primitive)130 PrimUniquePtr GetFullConnectionPrimitive(const std::vector<int8_t>& primitive)
131 {
132 FullConnection fullConnAttr;
133 auto ret = ParsePrimitive<FullConnection>(primitive, fullConnAttr, FullConnectionBlockUnmarshalling);
134 if (ret != HDF_SUCCESS) {
135 HDF_LOGE("Parse primitive data of FullConnection operator failed.");
136 return nullptr;
137 }
138
139 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
140 prim->value.type = mindspore::schema::PrimitiveType_FullConnection;
141
142 auto attr = new (std::nothrow) mindspore::schema::FullConnectionT;
143 if (attr == nullptr) {
144 HDF_LOGE("Create FullConnection primitive failed.");
145 return nullptr;
146 }
147
148 attr->has_bias = fullConnAttr.hasBias;
149 attr->use_axis = fullConnAttr.useAxis;
150 attr->axis = fullConnAttr.axis;
151 attr->activation_type = static_cast<mindspore::schema::ActivationType>(fullConnAttr.activationType);
152
153 prim->value.value = attr;
154 return prim;
155 }
156
GetMaxPoolFusionPrimitive(const std::vector<int8_t> & primitive)157 PrimUniquePtr GetMaxPoolFusionPrimitive(const std::vector<int8_t>& primitive)
158 {
159 MaxPoolFusion maxPoolAttr;
160 auto ret = ParsePrimitive<MaxPoolFusion>(primitive, maxPoolAttr, MaxPoolFusionBlockUnmarshalling);
161 if (ret != HDF_SUCCESS) {
162 HDF_LOGE("Parse primitive data of MaxPoolFusion operator failed.");
163 return nullptr;
164 }
165
166 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
167 prim->value.type = mindspore::schema::PrimitiveType_MaxPoolFusion;
168
169 auto attr = new (std::nothrow) mindspore::schema::MaxPoolFusionT;
170 if (attr == nullptr) {
171 HDF_LOGE("Create MaxPoolFusion primitive failed.");
172 return nullptr;
173 }
174
175 attr->kernel_size = maxPoolAttr.kernelSize;
176 attr->strides = maxPoolAttr.strides;
177 attr->pad = maxPoolAttr.pad;
178 attr->pad_mode = static_cast<mindspore::schema::PadMode>(maxPoolAttr.padMode);
179 attr->format = static_cast<mindspore::schema::Format>(maxPoolAttr.format);
180 attr->global = maxPoolAttr.global;
181 attr->activation_type = static_cast<mindspore::schema::ActivationType>(maxPoolAttr.activationType);
182
183 prim->value.value = attr;
184 return prim;
185 }
186
GetMatMulFusionPrimitive(const std::vector<int8_t> & primitive)187 PrimUniquePtr GetMatMulFusionPrimitive(const std::vector<int8_t>& primitive)
188 {
189 MatMulFusion matmulAttr;
190 auto ret = ParsePrimitive<MatMulFusion>(primitive, matmulAttr, MatMulFusionBlockUnmarshalling);
191 if (ret != HDF_SUCCESS) {
192 HDF_LOGE("Parse primitive data of MatMulFusion operator failed.");
193 return nullptr;
194 }
195
196 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
197 prim->value.type = mindspore::schema::PrimitiveType_MatMulFusion;
198
199 auto attr = new (std::nothrow) mindspore::schema::MatMulFusionT;
200 if (attr == nullptr) {
201 HDF_LOGE("Create MatMulFusion primitive failed.");
202 return nullptr;
203 }
204
205 attr->transpose_a = matmulAttr.transposeA;
206 attr->transpose_b = matmulAttr.transposeB;
207 attr->activation_type = static_cast<mindspore::schema::ActivationType>(matmulAttr.activationType);
208
209 prim->value.value = attr;
210 return prim;
211 }
212
GetSoftmaxPrimitive(const std::vector<int8_t> & primitive)213 PrimUniquePtr GetSoftmaxPrimitive(const std::vector<int8_t>& primitive)
214 {
215 Softmax softmaxAttr;
216 auto ret = ParsePrimitive<Softmax>(primitive, softmaxAttr, SoftmaxBlockUnmarshalling);
217 if (ret != HDF_SUCCESS) {
218 HDF_LOGE("Parse primitive data of Softmax operator failed.");
219 return nullptr;
220 }
221
222 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
223 prim->value.type = mindspore::schema::PrimitiveType_Softmax;
224
225 auto attr = new (std::nothrow) mindspore::schema::SoftmaxT;
226 if (attr == nullptr) {
227 HDF_LOGE("Create Softmax primitive failed.");
228 return nullptr;
229 }
230
231 attr->axis = softmaxAttr.axis;
232 prim->value.value = attr;
233 return prim;
234 }
235
GetReshapePrimitive(const std::vector<int8_t> & primitive)236 PrimUniquePtr GetReshapePrimitive(const std::vector<int8_t>& primitive)
237 {
238 Reshape reshapeAttr;
239 auto ret = ParsePrimitive<Reshape>(primitive, reshapeAttr, ReshapeBlockUnmarshalling);
240 if (ret != HDF_SUCCESS) {
241 HDF_LOGE("Parse primitive data of Reshape operator failed.");
242 return nullptr;
243 }
244
245 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
246 prim->value.type = mindspore::schema::PrimitiveType_Reshape;
247
248 auto attr = new (std::nothrow) mindspore::schema::ReshapeT;
249 if (attr == nullptr) {
250 HDF_LOGE("Create Reshape primitive failed.");
251 return nullptr;
252 }
253
254 prim->value.value = attr;
255 return prim;
256 }
257
GetScaleFusionPrimitive(const std::vector<int8_t> & primitive)258 PrimUniquePtr GetScaleFusionPrimitive(const std::vector<int8_t>& primitive)
259 {
260 ScaleFusion scaleAttr;
261 auto ret = ParsePrimitive<ScaleFusion>(primitive, scaleAttr, ScaleFusionBlockUnmarshalling);
262 if (ret != HDF_SUCCESS) {
263 HDF_LOGE("Parse primitive data of ScaleFusion operator failed.");
264 return nullptr;
265 }
266
267 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
268 prim->value.type = mindspore::schema::PrimitiveType_ScaleFusion;
269
270 auto attr = new (std::nothrow) mindspore::schema::ScaleFusionT;
271 if (attr == nullptr) {
272 HDF_LOGE("Create ScaleFusion primitive failed.");
273 return nullptr;
274 }
275
276 attr->axis = scaleAttr.axis;
277 attr->activation_type = static_cast<mindspore::schema::ActivationType>(scaleAttr.activationType);
278 prim->value.value = attr;
279 return prim;
280 }
281
GetActivationPrimitive(const std::vector<int8_t> & primitive)282 PrimUniquePtr GetActivationPrimitive(const std::vector<int8_t>& primitive)
283 {
284 Activation actAttr;
285 auto ret = ParsePrimitive<Activation>(primitive, actAttr, ActivationBlockUnmarshalling);
286 if (ret != HDF_SUCCESS) {
287 HDF_LOGE("Parse primitive data of Activation operator failed.");
288 return nullptr;
289 }
290
291 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
292 prim->value.type = mindspore::schema::PrimitiveType_Activation;
293
294 auto attr = new (std::nothrow) mindspore::schema::ActivationT;
295 if (attr == nullptr) {
296 HDF_LOGE("Create Activation primitive failed.");
297 return nullptr;
298 }
299
300 attr->alpha = actAttr.alpha;
301 attr->min_val = actAttr.minVal;
302 attr->max_val = actAttr.maxVal;
303 attr->approximate = actAttr.approximate;
304 attr->activation_type = static_cast<mindspore::schema::ActivationType>(actAttr.activationType);
305
306 prim->value.value = attr;
307 return prim;
308 }
309
GetQuantDTypeCastPrimitive(const std::vector<int8_t> & primitive)310 PrimUniquePtr GetQuantDTypeCastPrimitive(const std::vector<int8_t>& primitive)
311 {
312 QuantDTypeCast quantAttr;
313 auto ret = ParsePrimitive<QuantDTypeCast>(primitive, quantAttr, QuantDTypeCastBlockUnmarshalling);
314 if (ret != HDF_SUCCESS) {
315 HDF_LOGE("Parse primitive data of QuantDTypeCast operator failed.");
316 return nullptr;
317 }
318
319 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
320 prim->value.type = mindspore::schema::PrimitiveType_QuantDTypeCast;
321
322 auto attr = new (std::nothrow) mindspore::schema::QuantDTypeCastT;
323 if (attr == nullptr) {
324 HDF_LOGE("Create QuantDTypeCast primitive failed.");
325 return nullptr;
326 }
327
328 attr->src_t = quantAttr.srcT;
329 attr->dst_t = quantAttr.dstT;
330 prim->value.value = attr;
331 return prim;
332 }
333
GetMulFusionPrimitive(const std::vector<int8_t> & primitive)334 PrimUniquePtr GetMulFusionPrimitive(const std::vector<int8_t>& primitive)
335 {
336 MulFusion mulAttr;
337 auto ret = ParsePrimitive<MulFusion>(primitive, mulAttr, MulFusionBlockUnmarshalling);
338 if (ret != HDF_SUCCESS) {
339 HDF_LOGE("Parse primitive data of MulFusion operator failed.");
340 return nullptr;
341 }
342
343 auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
344 prim->value.type = mindspore::schema::PrimitiveType_MulFusion;
345
346 auto attr = new (std::nothrow) mindspore::schema::MulFusionT;
347 if (attr == nullptr) {
348 HDF_LOGE("Create MulFusion primitive failed.");
349 return nullptr;
350 }
351
352 attr->activation_type = static_cast<mindspore::schema::ActivationType>(mulAttr.activationType);
353 prim->value.value = attr;
354 return prim;
355 }
356
357 REGISTER_NODE(Activation, NodeType::NODE_TYPE_ACTIVATION, GetActivationPrimitive);
358 REGISTER_NODE(AddFusion, NodeType::NODE_TYPE_ADD_FUSION, GetAddPrimitive);
359 REGISTER_NODE(AvgPoolFusion, NodeType::NODE_TYPE_AVGPOOL_FUSION, GetAvgPoolPrimitive);
360 REGISTER_NODE(Concat, NodeType::NODE_TYPE_CONCAT, GetConcatPrimitive);
361 REGISTER_NODE(Conv2DFusion, NodeType::NODE_TYPE_CONV2D_FUSION, GetConv2dPrimitive);
362 REGISTER_NODE(FullConnection, NodeType::NODE_TYPE_FULL_CONNECTION, GetFullConnectionPrimitive);
363 REGISTER_NODE(MaxPoolFusion, NodeType::NODE_TYPE_MAX_POOL_FUSION, GetMaxPoolFusionPrimitive);
364 REGISTER_NODE(MatMulFusion, NodeType::NODE_TYPE_MATMUL_FUSION, GetMatMulFusionPrimitive);
365 REGISTER_NODE(Reshape, NodeType::NODE_TYPE_RESHAPE, GetReshapePrimitive);
366 REGISTER_NODE(Softmax, NodeType::NODE_TYPE_SOFTMAX, GetSoftmaxPrimitive);
367 REGISTER_NODE(ScaleFusion, NodeType::NODE_TYPE_SCALE_FUSION, GetScaleFusionPrimitive);
368 REGISTER_NODE(QuantDTypeCast, NodeType::NODE_TYPE_QUANT_DTYPE_CAST, GetQuantDTypeCastPrimitive);
369 REGISTER_NODE(MulFusion, NodeType::NODE_TYPE_MUL_FUSION, GetMulFusionPrimitive);
370 } // namespace V1_0
371 } // namespace Nnrt
372 } // namespace HDI
373 } // namespace OHOS