• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "node_functions.h"

#include <memory>
#include <new>

#include "node_registry.h"
#include <message_parcel.h>
#include <v1_0/node_attr_types.h>
21 
22 namespace OHOS {
23 namespace HDI {
24 namespace Nnrt {
25 namespace V1_0 {
GetAddPrimitive(const std::vector<int8_t> & primitive)26 PrimUniquePtr GetAddPrimitive(const std::vector<int8_t>& primitive)
27 {
28     AddFusion addAttr;
29     auto ret = ParsePrimitive<AddFusion>(primitive, addAttr, AddFusionBlockUnmarshalling);
30     if (ret != HDF_SUCCESS) {
31         HDF_LOGE("Parse primitive data of AddFusion operator failed.");
32         return nullptr;
33     }
34 
35     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
36     prim->value.type = mindspore::schema::PrimitiveType_AddFusion;
37     auto attr = new (std::nothrow) mindspore::schema::AddFusionT;
38     if (attr == nullptr) {
39         HDF_LOGE("Create AddFusion primitive failed.");
40         return nullptr;
41     }
42     attr->activation_type = static_cast<mindspore::schema::ActivationType>(addAttr.activationType);
43     prim->value.value = attr;
44     return prim;
45 }
46 
GetAvgPoolPrimitive(const std::vector<int8_t> & primitive)47 PrimUniquePtr GetAvgPoolPrimitive(const std::vector<int8_t>& primitive)
48 {
49     AvgPoolFusion avgPoolAttr;
50     auto ret = ParsePrimitive<AvgPoolFusion>(primitive, avgPoolAttr, AvgPoolFusionBlockUnmarshalling);
51     if (ret != HDF_SUCCESS) {
52         HDF_LOGE("Parse primitive data of AvgPoolFusion operator failed.");
53         return nullptr;
54     }
55 
56     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
57     prim->value.type = mindspore::schema::PrimitiveType_AvgPoolFusion;
58 
59     auto attr = new (std::nothrow) mindspore::schema::AvgPoolFusionT;
60     if (attr == nullptr) {
61         HDF_LOGE("Create AvgPoolFusion primitive failed.");
62         return nullptr;
63     }
64     attr->kernel_size = avgPoolAttr.kernelSize;
65     attr->strides = avgPoolAttr.strides;
66     attr->pad = avgPoolAttr.pad;
67     attr->pad_mode = static_cast<mindspore::schema::PadMode>(avgPoolAttr.padMode);
68     attr->round_mode = static_cast<mindspore::schema::RoundMode>(avgPoolAttr.roundMode);
69     attr->format = static_cast<mindspore::schema::Format>(avgPoolAttr.format);
70     attr->global = avgPoolAttr.global;
71     attr->activation_type = static_cast<mindspore::schema::ActivationType>(avgPoolAttr.activationType);
72     prim->value.value = attr;
73     return prim;
74 }
75 
GetConcatPrimitive(const std::vector<int8_t> & primitive)76 PrimUniquePtr GetConcatPrimitive(const std::vector<int8_t>& primitive)
77 {
78     Concat concatAttr;
79     auto ret = ParsePrimitive<Concat>(primitive, concatAttr, ConcatBlockUnmarshalling);
80     if (ret != HDF_SUCCESS) {
81         HDF_LOGE("Parse primitive data of Concat operator failed.");
82         return nullptr;
83     }
84 
85     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
86     prim->value.type = mindspore::schema::PrimitiveType_Concat;
87 
88     auto attr = new (std::nothrow) mindspore::schema::ConcatT;
89     if (attr == nullptr) {
90         HDF_LOGE("Create concat primitive failed.");
91         return nullptr;
92     }
93     attr->axis = concatAttr.axis;
94     prim->value.value = attr;
95     return prim;
96 }
97 
GetConv2dPrimitive(const std::vector<int8_t> & primitive)98 PrimUniquePtr GetConv2dPrimitive(const std::vector<int8_t>& primitive)
99 {
100     Conv2DFusion conv2dAttr;
101     auto ret = ParsePrimitive<Conv2DFusion>(primitive, conv2dAttr, Conv2DFusionBlockUnmarshalling);
102     if (ret != HDF_SUCCESS) {
103         HDF_LOGE("Parse primitive data of Conv2DFusion operator failed.");
104         return nullptr;
105     }
106 
107     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
108     prim->value.type = mindspore::schema::PrimitiveType_Conv2DFusion;
109 
110     auto attr = new (std::nothrow) mindspore::schema::Conv2DFusionT;
111     if (attr == nullptr) {
112         HDF_LOGE("Create Conv2DFusion primitive failed.");
113         return nullptr;
114     }
115 
116     attr->kernel_size = conv2dAttr.kernelSize;
117     attr->stride = conv2dAttr.stride;
118     attr->dilation = conv2dAttr.dilation;
119     attr->pad_mode = static_cast<mindspore::schema::PadMode>(conv2dAttr.padMode);
120     attr->pad_list = conv2dAttr.padList;
121     attr->group = conv2dAttr.group;
122     attr->in_channel = conv2dAttr.inChannel;
123     attr->out_channel = conv2dAttr.outChannel;
124     attr->activation_type = static_cast<mindspore::schema::ActivationType>(conv2dAttr.activationType);
125 
126     prim->value.value = attr;
127     return prim;
128 }
129 
GetFullConnectionPrimitive(const std::vector<int8_t> & primitive)130 PrimUniquePtr GetFullConnectionPrimitive(const std::vector<int8_t>& primitive)
131 {
132     FullConnection fullConnAttr;
133     auto ret = ParsePrimitive<FullConnection>(primitive, fullConnAttr, FullConnectionBlockUnmarshalling);
134     if (ret != HDF_SUCCESS) {
135         HDF_LOGE("Parse primitive data of FullConnection operator failed.");
136         return nullptr;
137     }
138 
139     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
140     prim->value.type = mindspore::schema::PrimitiveType_FullConnection;
141 
142     auto attr = new (std::nothrow) mindspore::schema::FullConnectionT;
143     if (attr == nullptr) {
144         HDF_LOGE("Create FullConnection primitive failed.");
145         return nullptr;
146     }
147 
148     attr->has_bias = fullConnAttr.hasBias;
149     attr->use_axis = fullConnAttr.useAxis;
150     attr->axis = fullConnAttr.axis;
151     attr->activation_type = static_cast<mindspore::schema::ActivationType>(fullConnAttr.activationType);
152 
153     prim->value.value = attr;
154     return prim;
155 }
156 
GetMaxPoolFusionPrimitive(const std::vector<int8_t> & primitive)157 PrimUniquePtr GetMaxPoolFusionPrimitive(const std::vector<int8_t>& primitive)
158 {
159     MaxPoolFusion maxPoolAttr;
160     auto ret = ParsePrimitive<MaxPoolFusion>(primitive, maxPoolAttr, MaxPoolFusionBlockUnmarshalling);
161     if (ret != HDF_SUCCESS) {
162         HDF_LOGE("Parse primitive data of MaxPoolFusion operator failed.");
163         return nullptr;
164     }
165 
166     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
167     prim->value.type = mindspore::schema::PrimitiveType_MaxPoolFusion;
168 
169     auto attr = new (std::nothrow) mindspore::schema::MaxPoolFusionT;
170     if (attr == nullptr) {
171         HDF_LOGE("Create MaxPoolFusion primitive failed.");
172         return nullptr;
173     }
174 
175     attr->kernel_size = maxPoolAttr.kernelSize;
176     attr->strides = maxPoolAttr.strides;
177     attr->pad = maxPoolAttr.pad;
178     attr->pad_mode = static_cast<mindspore::schema::PadMode>(maxPoolAttr.padMode);
179     attr->format = static_cast<mindspore::schema::Format>(maxPoolAttr.format);
180     attr->global = maxPoolAttr.global;
181     attr->activation_type = static_cast<mindspore::schema::ActivationType>(maxPoolAttr.activationType);
182 
183     prim->value.value = attr;
184     return prim;
185 }
186 
GetMatMulFusionPrimitive(const std::vector<int8_t> & primitive)187 PrimUniquePtr GetMatMulFusionPrimitive(const std::vector<int8_t>& primitive)
188 {
189     MatMulFusion matmulAttr;
190     auto ret = ParsePrimitive<MatMulFusion>(primitive, matmulAttr, MatMulFusionBlockUnmarshalling);
191     if (ret != HDF_SUCCESS) {
192         HDF_LOGE("Parse primitive data of MatMulFusion operator failed.");
193         return nullptr;
194     }
195 
196     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
197     prim->value.type = mindspore::schema::PrimitiveType_MatMulFusion;
198 
199     auto attr = new (std::nothrow) mindspore::schema::MatMulFusionT;
200     if (attr == nullptr) {
201         HDF_LOGE("Create MatMulFusion primitive failed.");
202         return nullptr;
203     }
204 
205     attr->transpose_a = matmulAttr.transposeA;
206     attr->transpose_b = matmulAttr.transposeB;
207     attr->activation_type = static_cast<mindspore::schema::ActivationType>(matmulAttr.activationType);
208 
209     prim->value.value = attr;
210     return prim;
211 }
212 
GetSoftmaxPrimitive(const std::vector<int8_t> & primitive)213 PrimUniquePtr GetSoftmaxPrimitive(const std::vector<int8_t>& primitive)
214 {
215     Softmax softmaxAttr;
216     auto ret = ParsePrimitive<Softmax>(primitive, softmaxAttr, SoftmaxBlockUnmarshalling);
217     if (ret != HDF_SUCCESS) {
218         HDF_LOGE("Parse primitive data of Softmax operator failed.");
219         return nullptr;
220     }
221 
222     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
223     prim->value.type = mindspore::schema::PrimitiveType_Softmax;
224 
225     auto attr = new (std::nothrow) mindspore::schema::SoftmaxT;
226     if (attr == nullptr) {
227         HDF_LOGE("Create Softmax primitive failed.");
228         return nullptr;
229     }
230 
231     attr->axis = softmaxAttr.axis;
232     prim->value.value = attr;
233     return prim;
234 }
235 
GetReshapePrimitive(const std::vector<int8_t> & primitive)236 PrimUniquePtr GetReshapePrimitive(const std::vector<int8_t>& primitive)
237 {
238     Reshape reshapeAttr;
239     auto ret = ParsePrimitive<Reshape>(primitive, reshapeAttr, ReshapeBlockUnmarshalling);
240     if (ret != HDF_SUCCESS) {
241         HDF_LOGE("Parse primitive data of Reshape operator failed.");
242         return nullptr;
243     }
244 
245     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
246     prim->value.type = mindspore::schema::PrimitiveType_Reshape;
247 
248     auto attr = new (std::nothrow) mindspore::schema::ReshapeT;
249     if (attr == nullptr) {
250         HDF_LOGE("Create Reshape primitive failed.");
251         return nullptr;
252     }
253 
254     prim->value.value = attr;
255     return prim;
256 }
257 
GetScaleFusionPrimitive(const std::vector<int8_t> & primitive)258 PrimUniquePtr GetScaleFusionPrimitive(const std::vector<int8_t>& primitive)
259 {
260     ScaleFusion scaleAttr;
261     auto ret = ParsePrimitive<ScaleFusion>(primitive, scaleAttr, ScaleFusionBlockUnmarshalling);
262     if (ret != HDF_SUCCESS) {
263         HDF_LOGE("Parse primitive data of ScaleFusion operator failed.");
264         return nullptr;
265     }
266 
267     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
268     prim->value.type = mindspore::schema::PrimitiveType_ScaleFusion;
269 
270     auto attr = new (std::nothrow) mindspore::schema::ScaleFusionT;
271     if (attr == nullptr) {
272         HDF_LOGE("Create ScaleFusion primitive failed.");
273         return nullptr;
274     }
275 
276     attr->axis = scaleAttr.axis;
277     attr->activation_type = static_cast<mindspore::schema::ActivationType>(scaleAttr.activationType);
278     prim->value.value = attr;
279     return prim;
280 }
281 
GetActivationPrimitive(const std::vector<int8_t> & primitive)282 PrimUniquePtr GetActivationPrimitive(const std::vector<int8_t>& primitive)
283 {
284     Activation actAttr;
285     auto ret = ParsePrimitive<Activation>(primitive, actAttr, ActivationBlockUnmarshalling);
286     if (ret != HDF_SUCCESS) {
287         HDF_LOGE("Parse primitive data of Activation operator failed.");
288         return nullptr;
289     }
290 
291     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
292     prim->value.type = mindspore::schema::PrimitiveType_Activation;
293 
294     auto attr = new (std::nothrow) mindspore::schema::ActivationT;
295     if (attr == nullptr) {
296         HDF_LOGE("Create Activation primitive failed.");
297         return nullptr;
298     }
299 
300     attr->alpha = actAttr.alpha;
301     attr->min_val = actAttr.minVal;
302     attr->max_val = actAttr.maxVal;
303     attr->approximate = actAttr.approximate;
304     attr->activation_type = static_cast<mindspore::schema::ActivationType>(actAttr.activationType);
305 
306     prim->value.value = attr;
307     return prim;
308 }
309 
GetQuantDTypeCastPrimitive(const std::vector<int8_t> & primitive)310 PrimUniquePtr GetQuantDTypeCastPrimitive(const std::vector<int8_t>& primitive)
311 {
312     QuantDTypeCast quantAttr;
313     auto ret = ParsePrimitive<QuantDTypeCast>(primitive, quantAttr, QuantDTypeCastBlockUnmarshalling);
314     if (ret != HDF_SUCCESS) {
315         HDF_LOGE("Parse primitive data of QuantDTypeCast operator failed.");
316         return nullptr;
317     }
318 
319     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
320     prim->value.type = mindspore::schema::PrimitiveType_QuantDTypeCast;
321 
322     auto attr = new (std::nothrow) mindspore::schema::QuantDTypeCastT;
323     if (attr == nullptr) {
324         HDF_LOGE("Create QuantDTypeCast primitive failed.");
325         return nullptr;
326     }
327 
328     attr->src_t = quantAttr.srcT;
329     attr->dst_t = quantAttr.dstT;
330     prim->value.value = attr;
331     return prim;
332 }
333 
GetMulFusionPrimitive(const std::vector<int8_t> & primitive)334 PrimUniquePtr GetMulFusionPrimitive(const std::vector<int8_t>& primitive)
335 {
336     MulFusion mulAttr;
337     auto ret = ParsePrimitive<MulFusion>(primitive, mulAttr, MulFusionBlockUnmarshalling);
338     if (ret != HDF_SUCCESS) {
339         HDF_LOGE("Parse primitive data of MulFusion operator failed.");
340         return nullptr;
341     }
342 
343     auto prim = std::make_unique<mindspore::schema::PrimitiveT>();
344     prim->value.type = mindspore::schema::PrimitiveType_MulFusion;
345 
346     auto attr = new (std::nothrow) mindspore::schema::MulFusionT;
347     if (attr == nullptr) {
348         HDF_LOGE("Create MulFusion primitive failed.");
349         return nullptr;
350     }
351 
352     attr->activation_type = static_cast<mindspore::schema::ActivationType>(mulAttr.activationType);
353     prim->value.value = attr;
354     return prim;
355 }
356 
357 REGISTER_NODE(Activation, NodeType::NODE_TYPE_ACTIVATION, GetActivationPrimitive);
358 REGISTER_NODE(AddFusion, NodeType::NODE_TYPE_ADD_FUSION, GetAddPrimitive);
359 REGISTER_NODE(AvgPoolFusion, NodeType::NODE_TYPE_AVGPOOL_FUSION, GetAvgPoolPrimitive);
360 REGISTER_NODE(Concat, NodeType::NODE_TYPE_CONCAT, GetConcatPrimitive);
361 REGISTER_NODE(Conv2DFusion, NodeType::NODE_TYPE_CONV2D_FUSION, GetConv2dPrimitive);
362 REGISTER_NODE(FullConnection, NodeType::NODE_TYPE_FULL_CONNECTION, GetFullConnectionPrimitive);
363 REGISTER_NODE(MaxPoolFusion, NodeType::NODE_TYPE_MAX_POOL_FUSION, GetMaxPoolFusionPrimitive);
364 REGISTER_NODE(MatMulFusion, NodeType::NODE_TYPE_MATMUL_FUSION, GetMatMulFusionPrimitive);
365 REGISTER_NODE(Reshape, NodeType::NODE_TYPE_RESHAPE, GetReshapePrimitive);
366 REGISTER_NODE(Softmax, NodeType::NODE_TYPE_SOFTMAX, GetSoftmaxPrimitive);
367 REGISTER_NODE(ScaleFusion, NodeType::NODE_TYPE_SCALE_FUSION, GetScaleFusionPrimitive);
368 REGISTER_NODE(QuantDTypeCast, NodeType::NODE_TYPE_QUANT_DTYPE_CAST, GetQuantDTypeCastPrimitive);
369 REGISTER_NODE(MulFusion, NodeType::NODE_TYPE_MUL_FUSION, GetMulFusionPrimitive);
370 } // namespace V1_0
371 } // namespace Nnrt
372 } // namespace HDI
373 } // namespace OHOS