1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "mapper/op_mapper.h"
18 #include <functional>
19 #include <algorithm>
20 #include "ops/tuple_get_item.h"
21 #include "common/op_attr.h"
22 #include "common/op_enum.h"
23 #include "common/anf_util.h"
24 #include "common/string_util.h"
25 #include "common/graph_output_name_keeper.h"
26 #include "third_party/securec/include/securec.h"
27
28 namespace mindspore {
29 namespace dpico {
30 namespace {
SetOpInputs(const api::CNodePtr & cnode,mapper::BaseOperator * base_operator)31 STATUS SetOpInputs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator) {
32 if (base_operator == nullptr) {
33 MS_LOG(ERROR) << "base_operator is nullptr.";
34 return RET_ERROR;
35 }
36 std::vector<std::string> input_names;
37 for (size_t i = 1; i < cnode->size(); i++) {
38 auto input_anode = cnode->input(i);
39 MS_ASSERT(input_anode != nullptr);
40 if (api::utils::isa<api::ParameterPtr>(input_anode)) {
41 auto param_node = input_anode->cast<api::ParameterPtr>();
42 if (param_node != nullptr && !param_node->has_default()) { // graph input
43 (void)input_names.emplace_back(input_anode->fullname_with_scope());
44 }
45 } else if (api::utils::isa<api::CNodePtr>(input_anode)) {
46 auto input_cnode = input_anode->cast<api::CNodePtr>();
47 if (input_cnode == nullptr) {
48 MS_LOG(ERROR) << "input node must be cnode.";
49 return RET_ERROR;
50 }
51 auto node_name = input_cnode->fullname_with_scope();
52 if (input_cnode->GetAttr(kOutputsNames) != nullptr) {
53 auto output_names = api::GetValue<std::vector<std::string>>(input_cnode->GetAttr(kOutputsNames));
54 if (output_names.size() == 1) {
55 node_name = output_names.at(0);
56 }
57 }
58 auto ret = dpico::GraphOutputNameKeeper::GetInstance()->DetermineOmOpInputName(input_cnode, &node_name);
59 MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine om op's input name failed.");
60 (void)input_names.emplace_back(node_name);
61 }
62 }
63 base_operator->SetInputNamesVec(input_names);
64 return RET_OK;
65 }
66
FillMultiOutOpOutputs(const api::CNodePtr & cnode,mapper::BaseOperator * base_operator,const api::CNodePtrList & output_cnodes)67 STATUS FillMultiOutOpOutputs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator,
68 const api::CNodePtrList &output_cnodes) {
69 MS_ASSERT(base_operator != nullptr);
70 if (std::any_of(output_cnodes.begin(), output_cnodes.end(), [](const api::CNodePtr &cnode) {
71 return !CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>());
72 })) {
73 MS_LOG(ERROR) << "multi-out op must be connected with tuple-get-item node.";
74 return RET_ERROR;
75 }
76 auto abstract = cnode->abstract();
77 if (abstract == nullptr) {
78 MS_LOG(ERROR) << "each node's abstract must be not a nullptr.";
79 return RET_ERROR;
80 }
81 if (!abstract->isa<api::AbstractTuple>()) {
82 MS_LOG(ERROR) << "multi-out op's abstract must be a tuple.";
83 return RET_ERROR;
84 }
85 auto abstract_tuple = abstract->cast<api::AbstractTuplePtr>();
86 MS_ASSERT(abstract_tuple != nullptr);
87 auto output_num = abstract_tuple->elements().size();
88 std::vector<std::string> output_names;
89 // pre-fill the output names, because maybe there are unused outputs.
90 for (size_t i = 0; i < output_num; ++i) {
91 (void)output_names.emplace_back(cnode->fullname_with_scope() + "_unused_" + std::to_string(i));
92 }
93 for (const auto &output_cnode : output_cnodes) {
94 if (output_cnode->size() != kInputIndex3) {
95 MS_LOG(ERROR) << "tuple-get_item's inputs size must be 3.";
96 return RET_ERROR;
97 }
98 auto index_node = output_cnode->input(kInputIndex2);
99 MS_CHECK_TRUE_MSG(index_node != nullptr, RET_ERROR, "node is nullptr.");
100 auto value_ptr = api::GetValueNode(index_node);
101 MS_CHECK_TRUE_MSG(value_ptr != nullptr, RET_ERROR, "tuple_get_item's second input must be a value.");
102 auto num_str = value_ptr->ToString();
103 MS_CHECK_TRUE_MSG(IsValidUnsignedNum(num_str), RET_ERROR, "tuple_get_item's second input must be an unsigned int");
104 auto index = stoi(num_str);
105 MS_CHECK_TRUE_MSG(index >= 0 && static_cast<size_t>(index) < output_num, RET_ERROR,
106 "tuple_get_item index is invalid.");
107 std::string om_output_name = output_cnode->fullname_with_scope();
108 auto ret = GraphOutputNameKeeper::GetInstance()->DetermineOmOpOutputName(cnode, &om_output_name);
109 MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "cannot determine the om op's output name.");
110 output_names[index] = om_output_name;
111 }
112 base_operator->SetOutputNamesVec(output_names);
113 return RET_OK;
114 }
115
SetOpOutputs(const api::CNodePtr & cnode,mapper::BaseOperator * base_operator,const api::CNodePtrList & output_cnodes)116 STATUS SetOpOutputs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator,
117 const api::CNodePtrList &output_cnodes) {
118 if (cnode == nullptr || base_operator == nullptr ||
119 std::any_of(output_cnodes.begin(), output_cnodes.end(),
120 [](const api::CNodePtr &cnode) { return cnode == nullptr; })) {
121 MS_LOG(ERROR) << "the function exist that input parameter is a nullptr.";
122 return RET_ERROR;
123 }
124 if (std::all_of(output_cnodes.begin(), output_cnodes.end(), [](const api::CNodePtr &cnode) {
125 return !CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>());
126 })) {
127 // single output op
128 std::vector<std::string> output_names;
129 std::string om_output_name = cnode->fullname_with_scope();
130 auto ret = GraphOutputNameKeeper::GetInstance()->DetermineOmOpOutputName(cnode, &om_output_name);
131 MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "cannot determine the om op's output name.");
132 (void)output_names.emplace_back(om_output_name);
133 base_operator->SetOutputNamesVec(output_names);
134 return RET_OK;
135 }
136
137 // multi output op
138 if (FillMultiOutOpOutputs(cnode, base_operator, output_cnodes) != RET_OK) {
139 MS_LOG(ERROR) << "set multi-out op's output names failed.";
140 return RET_ERROR;
141 }
142 return RET_OK;
143 }
144 } // namespace
145
SetCommonAttr(const api::CNodePtr & cnode,mapper::BaseOperator * base_operator,const api::CNodePtrList & output_cnodes)146 STATUS SetCommonAttr(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator,
147 const api::CNodePtrList &output_cnodes) {
148 if (base_operator == nullptr) {
149 MS_LOG(ERROR) << "base operator is nullptr.";
150 return RET_ERROR;
151 }
152 base_operator->SetOpName(cnode->fullname_with_scope());
153 if (SetOpInputs(cnode, base_operator) != RET_OK) {
154 MS_LOG(ERROR) << "set op inputs failed. " << cnode->fullname_with_scope();
155 return RET_ERROR;
156 }
157 if (SetOpOutputs(cnode, base_operator, output_cnodes) != RET_OK) {
158 MS_LOG(ERROR) << "set op outputs failed. " << cnode->fullname_with_scope();
159 return RET_ERROR;
160 }
161 return RET_OK;
162 }
163
SetConvFcDataInfo(const api::CNodePtr & cnode,mapper::BaseOperator * base_operator)164 STATUS SetConvFcDataInfo(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator) {
165 if (base_operator == nullptr) {
166 MS_LOG(ERROR) << "base_operator is nullptr.";
167 return RET_ERROR;
168 }
169 for (size_t i = 2; i < cnode->size(); i++) {
170 auto input_node = cnode->input(i);
171 MS_ASSERT(input_node != nullptr);
172 auto param_node = input_node->cast<api::ParameterPtr>();
173 if (param_node == nullptr || !param_node->has_default()) {
174 continue;
175 }
176 auto tensor_info = param_node->default_param()->cast<api::TensorPtr>();
177 if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
178 auto data = reinterpret_cast<float *>(tensor_info->data());
179 MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr.");
180 if (i == kInputIndex2) {
181 base_operator->SetWeightDataPtr(data);
182 base_operator->SetWeightSize(tensor_info->DataSize());
183 } else if (i == kInputIndex3) {
184 base_operator->SetBiasDataPtr(data);
185 base_operator->SetBiasSize(tensor_info->DataSize());
186 } else {
187 MS_LOG(ERROR) << "conv or fc operator only support 2 offline inputs at most, but "
188 << cnode->fullname_with_scope() << " has " << i << " offline inputs.";
189 return RET_ERROR;
190 }
191 } else {
192 MS_LOG(ERROR) << "param node's tensor info is invalid. " << input_node->fullname_with_scope();
193 return RET_ERROR;
194 }
195 }
196
197 return RET_OK;
198 }
SetRecurrentDataInfo(const api::CNodePtr & cnode,mapper::RecurrentOperator * recurrent_operator)199 STATUS SetRecurrentDataInfo(const api::CNodePtr &cnode, mapper::RecurrentOperator *recurrent_operator) {
200 if (recurrent_operator == nullptr) {
201 MS_LOG(ERROR) << "recurrent_operator is nullptr.";
202 return RET_ERROR;
203 }
204 for (size_t i = 1; i < cnode->size(); i++) {
205 auto input_node = cnode->input(i);
206 if (api::utils::isa<api::CNode>(input_node)) {
207 MS_LOG(INFO) << "cnode don't have blobs";
208 continue;
209 }
210 if (api::utils::isa<api::ParameterPtr>(input_node)) {
211 auto input_param_node = input_node->cast<api::ParameterPtr>();
212 if (!input_param_node->has_default()) {
213 MS_LOG(INFO) << "graph input don't have blobs";
214 continue;
215 }
216 auto tensor_info = input_param_node->default_param()->cast<api::TensorPtr>();
217 if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
218 auto raw_datas = static_cast<float *>(tensor_info->data());
219 auto elem_count = tensor_info->DataSize();
220 auto weight_data = new (std::nothrow) float[tensor_info->DataSize()];
221 if (weight_data == nullptr) {
222 MS_LOG(ERROR) << "new float[] failed.";
223 return RET_ERROR;
224 }
225 if (memcpy_s(weight_data, static_cast<size_t>(tensor_info->DataSize()) * sizeof(float), raw_datas,
226 static_cast<size_t>(tensor_info->DataSize()) * sizeof(float)) != EOK) {
227 MS_LOG(ERROR) << "memcpy_s failed.";
228 delete[] weight_data;
229 return RET_ERROR;
230 }
231 recurrent_operator->AddRecurrentParamVec(weight_data);
232 recurrent_operator->AddRecurrentParamLengthVec(elem_count);
233 } else {
234 MS_LOG(ERROR) << "tensor_info is nullptr, or DataSize equals zero. " << cnode->fullname_with_scope();
235 return RET_ERROR;
236 }
237 }
238 }
239 return RET_OK;
240 }
SetRecurrentOnnxInfo(const api::CNodePtr & cnode,mapper::RecurrentOperator * recurrent_operator)241 STATUS SetRecurrentOnnxInfo(const api::CNodePtr &cnode, mapper::RecurrentOperator *recurrent_operator) {
242 if (recurrent_operator == nullptr) {
243 MS_LOG(ERROR) << "recurrent_operator is nullptr.";
244 return RET_ERROR;
245 }
246 for (size_t i = 1; i < cnode->size(); i++) {
247 auto input_node = cnode->input(i);
248 if (api::utils::isa<api::CNode>(input_node)) {
249 MS_LOG(INFO) << "cnode don't have blobs";
250 continue;
251 }
252 if (api::utils::isa<api::ParameterPtr>(input_node)) {
253 auto input_param_node = input_node->cast<api::ParameterPtr>();
254 if (!input_param_node->has_default()) {
255 MS_LOG(INFO) << "graph input don't have blobs";
256 continue;
257 }
258 auto tensor_info = input_param_node->default_param()->cast<api::TensorPtr>();
259 if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
260 auto raw_datas = static_cast<float *>(tensor_info->data());
261 auto shape = tensor_info->shape();
262 vector<int32_t> shape_vec(shape.begin(), shape.end());
263 auto weight_data = new (std::nothrow) float[tensor_info->DataSize()];
264 if (weight_data == nullptr) {
265 MS_LOG(ERROR) << "new float[] failed.";
266 return RET_ERROR;
267 }
268 if (memcpy_s(weight_data, static_cast<size_t>(tensor_info->DataSize()) * sizeof(float), raw_datas,
269 static_cast<size_t>(tensor_info->DataSize()) * sizeof(float)) != EOK) {
270 MS_LOG(ERROR) << "memcpy_s failed.";
271 delete[] weight_data;
272 return RET_ERROR;
273 }
274 if (SetOnnxLstmOffLineArgs(recurrent_operator, i, shape_vec, weight_data) != RET_OK) {
275 MS_LOG(ERROR) << "set offline args failed.";
276 return RET_ERROR;
277 }
278 if (i == kDims5) {
279 std::vector<std::pair<std::vector<float>, std::vector<int32_t>>> offline_args;
280 std::vector<float> offline_data;
281 recurrent_operator->PushOfflineArgs({});
282 if (CheckTensorInfoType(tensor_info, &offline_data) != RET_OK) {
283 MS_LOG(ERROR) << "check tensor_info type failed.";
284 return RET_ERROR;
285 }
286 std::vector<int32_t> offline_shape;
287 ShapeVector shape_vector;
288 if (GetShapeVectorFromParameter(input_param_node, &shape_vector) != RET_OK) {
289 MS_LOG(ERROR) << "get shape vector from parameter failed. " << input_param_node->fullname_with_scope();
290 return RET_ERROR;
291 }
292 (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(offline_shape),
293 [](const int64_t dim) { return static_cast<int32_t>(dim); });
294 (void)offline_args.emplace_back(std::make_pair(offline_data, offline_shape));
295 for (auto &offline_arg : offline_args) {
296 recurrent_operator->PushOfflineArgs(std::move(offline_arg));
297 }
298 }
299 } else {
300 MS_LOG(ERROR) << "tensor_info is nullptr, or DataSize equals zero. " << cnode->fullname_with_scope();
301 return RET_ERROR;
302 }
303 }
304 }
305 return RET_OK;
306 }
CheckTensorInfoType(const api::TensorPtr & tensor_info,std::vector<float> * offline_data)307 STATUS CheckTensorInfoType(const api::TensorPtr &tensor_info, std::vector<float> *offline_data) {
308 auto elem_count = tensor_info->DataSize();
309 if (tensor_info->data_type() == kNumberTypeInt32 || tensor_info->data_type() == kNumberTypeInt) {
310 auto raw_data = static_cast<int32_t *>(tensor_info->data());
311 *offline_data = std::vector<float>(raw_data, raw_data + elem_count);
312 } else if (tensor_info->data_type() == kNumberTypeFloat32 || tensor_info->data_type() == kNumberTypeFloat) {
313 auto raw_data = static_cast<float *>(tensor_info->data());
314 *offline_data = std::vector<float>(raw_data, raw_data + elem_count);
315 } else {
316 MS_LOG(ERROR) << "unsupported param type. " << tensor_info->data_type();
317 return RET_ERROR;
318 }
319 return RET_OK;
320 }
SetOnnxLstmOffLineArgs(mapper::RecurrentOperator * recurrent_operator,size_t index,const vector<int32_t> & shape_vec,const float * data)321 STATUS SetOnnxLstmOffLineArgs(mapper::RecurrentOperator *recurrent_operator, size_t index,
322 const vector<int32_t> &shape_vec, const float *data) {
323 if (index == kDims2) {
324 recurrent_operator->SetXtShapeVec(shape_vec);
325 recurrent_operator->SetXtWeightDataPtr(data);
326 } else if (index == kDims3) {
327 recurrent_operator->SetHtShapeVec(shape_vec);
328 recurrent_operator->SetHtWeightDataPtr(data);
329 } else if (index == kDims4) {
330 recurrent_operator->SetRecurrentBiasShapeVec(shape_vec);
331 recurrent_operator->SetRecurrentBiasDataPtr(data);
332 } else if (index == kDims8) {
333 recurrent_operator->SetPeepholesShapeVec(shape_vec);
334 recurrent_operator->SetPeepholesWeightDataPtr(data);
335 }
336 return RET_OK;
337 }
PushOfflineArgs(const api::CNodePtr & cnode,mapper::BaseOperator * base_operator,size_t offline_args_size)338 STATUS PushOfflineArgs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator, size_t offline_args_size) {
339 if (base_operator == nullptr) {
340 MS_LOG(ERROR) << "base_operator is nullptr.";
341 return RET_ERROR;
342 }
343 if (offline_args_size > cnode->size()) {
344 MS_LOG(ERROR) << "input offline_args_size:" << offline_args_size
345 << " is greater than cnode input size:" << cnode->size() << " " << cnode->fullname_with_scope();
346 return RET_ERROR;
347 }
348 auto inputs_size = std::min(offline_args_size + 1, cnode->size());
349 std::vector<std::pair<std::vector<float>, std::vector<int32_t>>> offline_args;
350 bool has_offline_args = false;
351 for (size_t i = 1; i < inputs_size; i++) {
352 auto input_node = cnode->input(i);
353 if (api::utils::isa<api::CNode>(input_node)) {
354 MS_LOG(INFO) << "cnode don't have blobs";
355 (void)offline_args.emplace_back();
356 continue;
357 }
358 if (api::utils::isa<api::ParameterPtr>(input_node)) {
359 auto input_param_node = input_node->cast<api::ParameterPtr>();
360 if (!input_param_node->has_default()) {
361 MS_LOG(INFO) << "graph input don't have blobs";
362 (void)offline_args.emplace_back();
363 continue;
364 }
365 auto tensor_info = input_param_node->default_param()->cast<api::TensorPtr>();
366 if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
367 has_offline_args = true;
368 std::vector<float> offline_data;
369 auto elem_count = tensor_info->DataSize();
370 if (tensor_info->data_type() == kNumberTypeInt32 || tensor_info->data_type() == kNumberTypeInt) {
371 auto raw_datas = static_cast<int32_t *>(tensor_info->data());
372 offline_data = std::vector<float>(raw_datas, raw_datas + elem_count);
373 } else if (tensor_info->data_type() == kNumberTypeFloat32 || tensor_info->data_type() == kNumberTypeFloat) {
374 auto raw_datas = static_cast<float *>(tensor_info->data());
375 offline_data = std::vector<float>(raw_datas, raw_datas + elem_count);
376 } else {
377 MS_LOG(ERROR) << "unsupported param type. " << tensor_info->data_type();
378 return RET_ERROR;
379 }
380 std::vector<int32_t> offline_shape;
381 ShapeVector shape_vector;
382 if (GetShapeVectorFromParameter(input_param_node, &shape_vector) != RET_OK) {
383 MS_LOG(ERROR) << "get shape vector from parameter failed. " << input_param_node->fullname_with_scope();
384 return RET_ERROR;
385 }
386 (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(offline_shape),
387 [](const int64_t dim) { return static_cast<int32_t>(dim); });
388 (void)offline_args.emplace_back(std::make_pair(offline_data, offline_shape));
389 } else {
390 MS_LOG(ERROR) << "tensor_info is nullptr, or DataSize equals zero. " << cnode->fullname_with_scope();
391 return RET_ERROR;
392 }
393 }
394 }
395 if (has_offline_args) {
396 for (auto &offline_arg : offline_args) {
397 base_operator->PushOfflineArgs(std::move(offline_arg));
398 }
399 }
400 return RET_OK;
401 }
402 } // namespace dpico
403 } // namespace mindspore
404