/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mapper/op_mapper.h"
#include <functional>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "ops/tuple_get_item.h"
#include "common/op_attr.h"
#include "common/op_enum.h"
#include "common/anf_util.h"
#include "common/string_util.h"
#include "common/graph_output_name_keeper.h"
#include "third_party/securec/include/securec.h"

namespace mindspore {
namespace dpico {
namespace {
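// Collect the input tensor names for base_operator: a graph input (parameter
// without a default value) contributes its full scope name, while a cnode input
// contributes its single recorded output name (when kOutputsNames holds exactly
// one entry) or a name resolved through GraphOutputNameKeeper.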
STATUS SetOpInputs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator) {
  if (base_operator == nullptr) {
    MS_LOG(ERROR) << "base_operator is nullptr.";
    return RET_ERROR;
  }
  std::vector<std::string> input_names;
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_anode = cnode->input(i);
    MS_ASSERT(input_anode != nullptr);
    if (api::utils::isa<api::ParameterPtr>(input_anode)) {
      auto param_node = input_anode->cast<api::ParameterPtr>();
      if (param_node != nullptr && !param_node->has_default()) {  // graph input
        (void)input_names.emplace_back(input_anode->fullname_with_scope());
      }
    } else if (api::utils::isa<api::CNodePtr>(input_anode)) {
      auto input_cnode = input_anode->cast<api::CNodePtr>();
      if (input_cnode == nullptr) {
        MS_LOG(ERROR) << "input node must be a cnode.";
        return RET_ERROR;
      }
      auto node_name = input_cnode->fullname_with_scope();
      if (input_cnode->GetAttr(kOutputsNames) != nullptr) {
        auto output_names = api::GetValue<std::vector<std::string>>(input_cnode->GetAttr(kOutputsNames));
        if (output_names.size() == 1) {
          node_name = output_names.at(0);
        }
      }
      auto ret = dpico::GraphOutputNameKeeper::GetInstance()->DetermineOmOpInputName(input_cnode, &node_name);
      MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine om op's input name failed.");
      (void)input_names.emplace_back(node_name);
    }
  }
  base_operator->SetInputNamesVec(input_names);
  return RET_OK;
}

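// Fill the output names of a multi-output op. Every consumer in output_cnodes
// must be a TupleGetItem. Names are pre-filled with "<scope>_unused_<i>"
// placeholders so that unused outputs still occupy a slot, then overwritten at
// the index each TupleGetItem selects.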
STATUS FillMultiOutOpOutputs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator,
                             const api::CNodePtrList &output_cnodes) {
  MS_ASSERT(base_operator != nullptr);
  if (std::any_of(output_cnodes.begin(), output_cnodes.end(), [](const api::CNodePtr &cnode) {
        return !CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>());
      })) {
    MS_LOG(ERROR) << "multi-out op must be connected with a tuple-get-item node.";
    return RET_ERROR;
  }
  auto abstract = cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "each node's abstract must not be a nullptr.";
    return RET_ERROR;
  }
  if (!abstract->isa<api::AbstractTuple>()) {
    MS_LOG(ERROR) << "multi-out op's abstract must be a tuple.";
    return RET_ERROR;
  }
  auto abstract_tuple = abstract->cast<api::AbstractTuplePtr>();
  MS_ASSERT(abstract_tuple != nullptr);
  auto output_num = abstract_tuple->elements().size();
  std::vector<std::string> output_names;
  // pre-fill the output names, because there may be unused outputs.
  for (size_t i = 0; i < output_num; ++i) {
    (void)output_names.emplace_back(cnode->fullname_with_scope() + "_unused_" + std::to_string(i));
  }
  for (const auto &output_cnode : output_cnodes) {
    if (output_cnode->size() != kInputIndex3) {
      MS_LOG(ERROR) << "tuple_get_item's input size must be 3.";
      return RET_ERROR;
    }
    auto index_node = output_cnode->input(kInputIndex2);
    MS_CHECK_TRUE_MSG(index_node != nullptr, RET_ERROR, "node is nullptr.");
    auto value_ptr = api::GetValueNode(index_node);
    MS_CHECK_TRUE_MSG(value_ptr != nullptr, RET_ERROR, "tuple_get_item's second input must be a value.");
    auto num_str = value_ptr->ToString();
    MS_CHECK_TRUE_MSG(IsValidUnsignedNum(num_str), RET_ERROR, "tuple_get_item's second input must be an unsigned int");
    auto index = std::stoi(num_str);
    MS_CHECK_TRUE_MSG(index >= 0 && static_cast<size_t>(index) < output_num, RET_ERROR,
                      "tuple_get_item index is invalid.");
    std::string om_output_name = output_cnode->fullname_with_scope();
    auto ret = GraphOutputNameKeeper::GetInstance()->DetermineOmOpOutputName(cnode, &om_output_name);
    MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "cannot determine the om op's output name.");
    output_names[index] = om_output_name;
  }
  base_operator->SetOutputNamesVec(output_names);
  return RET_OK;
}

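// Set the output names of cnode on base_operator. If none of the consumers is
// a TupleGetItem, the op is treated as single-output and named after the cnode
// itself (possibly remapped by GraphOutputNameKeeper); otherwise the
// multi-output path above is taken.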
STATUS SetOpOutputs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator,
                    const api::CNodePtrList &output_cnodes) {
  if (cnode == nullptr || base_operator == nullptr ||
      std::any_of(output_cnodes.begin(), output_cnodes.end(),
                  [](const api::CNodePtr &cnode) { return cnode == nullptr; })) {
    MS_LOG(ERROR) << "an input parameter of the function is a nullptr.";
    return RET_ERROR;
  }
  if (std::all_of(output_cnodes.begin(), output_cnodes.end(), [](const api::CNodePtr &cnode) {
        return !CheckPrimitiveType(cnode, api::MakeShared<ops::TupleGetItem>());
      })) {
    // single output op
    std::vector<std::string> output_names;
    std::string om_output_name = cnode->fullname_with_scope();
    auto ret = GraphOutputNameKeeper::GetInstance()->DetermineOmOpOutputName(cnode, &om_output_name);
    MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "cannot determine the om op's output name.");
    (void)output_names.emplace_back(om_output_name);
    base_operator->SetOutputNamesVec(output_names);
    return RET_OK;
  }

  // multi output op
  if (FillMultiOutOpOutputs(cnode, base_operator, output_cnodes) != RET_OK) {
    MS_LOG(ERROR) << "set multi-out op's output names failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace

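// Entry point shared by the concrete op mappers: record the op name on
// base_operator, then delegate to SetOpInputs and SetOpOutputs.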
STATUS SetCommonAttr(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator,
                     const api::CNodePtrList &output_cnodes) {
  if (base_operator == nullptr) {
    MS_LOG(ERROR) << "base operator is nullptr.";
    return RET_ERROR;
  }
  base_operator->SetOpName(cnode->fullname_with_scope());
  if (SetOpInputs(cnode, base_operator) != RET_OK) {
    MS_LOG(ERROR) << "set op inputs failed. " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  if (SetOpOutputs(cnode, base_operator, output_cnodes) != RET_OK) {
    MS_LOG(ERROR) << "set op outputs failed. " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  return RET_OK;
}

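// Bind the offline blobs of a conv/fc node to base_operator: input index 2 is
// the weight, input index 3 is the bias; any further const input is rejected.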
STATUS SetConvFcDataInfo(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator) {
  if (base_operator == nullptr) {
    MS_LOG(ERROR) << "base_operator is nullptr.";
    return RET_ERROR;
  }
  for (size_t i = 2; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    MS_ASSERT(input_node != nullptr);
    auto param_node = input_node->cast<api::ParameterPtr>();
    if (param_node == nullptr || !param_node->has_default()) {
      continue;
    }
    auto tensor_info = param_node->default_param()->cast<api::TensorPtr>();
    if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
      auto data = reinterpret_cast<float *>(tensor_info->data());
      MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr.");
      if (i == kInputIndex2) {
        base_operator->SetWeightDataPtr(data);
        base_operator->SetWeightSize(tensor_info->DataSize());
      } else if (i == kInputIndex3) {
        base_operator->SetBiasDataPtr(data);
        base_operator->SetBiasSize(tensor_info->DataSize());
      } else {
        MS_LOG(ERROR) << "a conv or fc operator supports at most 2 offline inputs, but "
                      << cnode->fullname_with_scope() << " has " << i << " offline inputs.";
        return RET_ERROR;
      }
    } else {
      MS_LOG(ERROR) << "param node's tensor info is invalid. " << input_node->fullname_with_scope();
      return RET_ERROR;
    }
  }

  return RET_OK;
}
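// Copy every offline blob of a recurrent node into a freshly allocated float
// buffer and hand it, together with its element count, to recurrent_operator;
// ownership of the copied buffers is assumed to pass to the mapper.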
STATUS SetRecurrentDataInfo(const api::CNodePtr &cnode, mapper::RecurrentOperator *recurrent_operator) {
  if (recurrent_operator == nullptr) {
    MS_LOG(ERROR) << "recurrent_operator is nullptr.";
    return RET_ERROR;
  }
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    if (api::utils::isa<api::CNode>(input_node)) {
      MS_LOG(INFO) << "cnode doesn't have blobs";
      continue;
    }
    if (api::utils::isa<api::ParameterPtr>(input_node)) {
      auto input_param_node = input_node->cast<api::ParameterPtr>();
      if (!input_param_node->has_default()) {
        MS_LOG(INFO) << "graph input doesn't have blobs";
        continue;
      }
      auto tensor_info = input_param_node->default_param()->cast<api::TensorPtr>();
      if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
        auto raw_datas = static_cast<float *>(tensor_info->data());
        auto elem_count = tensor_info->DataSize();
        auto weight_data = new (std::nothrow) float[tensor_info->DataSize()];
        if (weight_data == nullptr) {
          MS_LOG(ERROR) << "new float[] failed.";
          return RET_ERROR;
        }
        if (memcpy_s(weight_data, static_cast<size_t>(tensor_info->DataSize()) * sizeof(float), raw_datas,
                     static_cast<size_t>(tensor_info->DataSize()) * sizeof(float)) != EOK) {
          MS_LOG(ERROR) << "memcpy_s failed.";
          delete[] weight_data;
          return RET_ERROR;
        }
        recurrent_operator->AddRecurrentParamVec(weight_data);
        recurrent_operator->AddRecurrentParamLengthVec(elem_count);
      } else {
        MS_LOG(ERROR) << "tensor_info is nullptr, or DataSize equals zero. " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
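// ONNX-LSTM variant of the blob transfer: copy each offline blob, register it
// by input index via SetOnnxLstmOffLineArgs, and for input index kDims5
// additionally push the data (converted to float) and shape as offline args.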
STATUS SetRecurrentOnnxInfo(const api::CNodePtr &cnode, mapper::RecurrentOperator *recurrent_operator) {
  if (recurrent_operator == nullptr) {
    MS_LOG(ERROR) << "recurrent_operator is nullptr.";
    return RET_ERROR;
  }
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    if (api::utils::isa<api::CNode>(input_node)) {
      MS_LOG(INFO) << "cnode doesn't have blobs";
      continue;
    }
    if (api::utils::isa<api::ParameterPtr>(input_node)) {
      auto input_param_node = input_node->cast<api::ParameterPtr>();
      if (!input_param_node->has_default()) {
        MS_LOG(INFO) << "graph input doesn't have blobs";
        continue;
      }
      auto tensor_info = input_param_node->default_param()->cast<api::TensorPtr>();
      if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
        auto raw_datas = static_cast<float *>(tensor_info->data());
        auto shape = tensor_info->shape();
        std::vector<int32_t> shape_vec(shape.begin(), shape.end());
        auto weight_data = new (std::nothrow) float[tensor_info->DataSize()];
        if (weight_data == nullptr) {
          MS_LOG(ERROR) << "new float[] failed.";
          return RET_ERROR;
        }
        if (memcpy_s(weight_data, static_cast<size_t>(tensor_info->DataSize()) * sizeof(float), raw_datas,
                     static_cast<size_t>(tensor_info->DataSize()) * sizeof(float)) != EOK) {
          MS_LOG(ERROR) << "memcpy_s failed.";
          delete[] weight_data;
          return RET_ERROR;
        }
        if (SetOnnxLstmOffLineArgs(recurrent_operator, i, shape_vec, weight_data) != RET_OK) {
          MS_LOG(ERROR) << "set offline args failed.";
          return RET_ERROR;
        }
        if (i == kDims5) {
          std::vector<std::pair<std::vector<float>, std::vector<int32_t>>> offline_args;
          std::vector<float> offline_data;
          recurrent_operator->PushOfflineArgs({});
          if (CheckTensorInfoType(tensor_info, &offline_data) != RET_OK) {
            MS_LOG(ERROR) << "check tensor_info type failed.";
            return RET_ERROR;
          }
          std::vector<int32_t> offline_shape;
          ShapeVector shape_vector;
          if (GetShapeVectorFromParameter(input_param_node, &shape_vector) != RET_OK) {
            MS_LOG(ERROR) << "get shape vector from parameter failed. " << input_param_node->fullname_with_scope();
            return RET_ERROR;
          }
          (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(offline_shape),
                               [](const int64_t dim) { return static_cast<int32_t>(dim); });
          (void)offline_args.emplace_back(std::make_pair(offline_data, offline_shape));
          for (auto &offline_arg : offline_args) {
            recurrent_operator->PushOfflineArgs(std::move(offline_arg));
          }
        }
      } else {
        MS_LOG(ERROR) << "tensor_info is nullptr, or DataSize equals zero. " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
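// Convert an int32 or float32 tensor into a plain float vector; any other
// element type is rejected.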
STATUS CheckTensorInfoType(const api::TensorPtr &tensor_info, std::vector<float> *offline_data) {
  auto elem_count = tensor_info->DataSize();
  if (tensor_info->data_type() == kNumberTypeInt32 || tensor_info->data_type() == kNumberTypeInt) {
    auto raw_data = static_cast<int32_t *>(tensor_info->data());
    *offline_data = std::vector<float>(raw_data, raw_data + elem_count);
  } else if (tensor_info->data_type() == kNumberTypeFloat32 || tensor_info->data_type() == kNumberTypeFloat) {
    auto raw_data = static_cast<float *>(tensor_info->data());
    *offline_data = std::vector<float>(raw_data, raw_data + elem_count);
  } else {
    MS_LOG(ERROR) << "unsupported param type. " << tensor_info->data_type();
    return RET_ERROR;
  }
  return RET_OK;
}
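// Map an ONNX-LSTM input index onto the matching weight slot of the recurrent
// operator: Xt weights, Ht weights, bias, or peepholes. Indices without a
// dedicated slot are accepted silently.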
STATUS SetOnnxLstmOffLineArgs(mapper::RecurrentOperator *recurrent_operator, size_t index,
                              const std::vector<int32_t> &shape_vec, const float *data) {
  if (index == kDims2) {
    recurrent_operator->SetXtShapeVec(shape_vec);
    recurrent_operator->SetXtWeightDataPtr(data);
  } else if (index == kDims3) {
    recurrent_operator->SetHtShapeVec(shape_vec);
    recurrent_operator->SetHtWeightDataPtr(data);
  } else if (index == kDims4) {
    recurrent_operator->SetRecurrentBiasShapeVec(shape_vec);
    recurrent_operator->SetRecurrentBiasDataPtr(data);
  } else if (index == kDims8) {
    recurrent_operator->SetPeepholesShapeVec(shape_vec);
    recurrent_operator->SetPeepholesWeightDataPtr(data);
  }
  return RET_OK;
}
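// Gather up to offline_args_size leading inputs of cnode as (data, shape)
// pairs: cnode and graph inputs contribute empty placeholders, const
// parameters are converted to float vectors. The collected pairs are pushed to
// base_operator only if at least one real offline blob was found.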
STATUS PushOfflineArgs(const api::CNodePtr &cnode, mapper::BaseOperator *base_operator, size_t offline_args_size) {
  if (base_operator == nullptr) {
    MS_LOG(ERROR) << "base_operator is nullptr.";
    return RET_ERROR;
  }
  if (offline_args_size > cnode->size()) {
    MS_LOG(ERROR) << "input offline_args_size:" << offline_args_size
                  << " is greater than cnode input size:" << cnode->size() << " " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  auto inputs_size = std::min(offline_args_size + 1, cnode->size());
  std::vector<std::pair<std::vector<float>, std::vector<int32_t>>> offline_args;
  bool has_offline_args = false;
  for (size_t i = 1; i < inputs_size; i++) {
    auto input_node = cnode->input(i);
    if (api::utils::isa<api::CNode>(input_node)) {
      MS_LOG(INFO) << "cnode doesn't have blobs";
      (void)offline_args.emplace_back();
      continue;
    }
    if (api::utils::isa<api::ParameterPtr>(input_node)) {
      auto input_param_node = input_node->cast<api::ParameterPtr>();
      if (!input_param_node->has_default()) {
        MS_LOG(INFO) << "graph input doesn't have blobs";
        (void)offline_args.emplace_back();
        continue;
      }
      auto tensor_info = input_param_node->default_param()->cast<api::TensorPtr>();
      if (tensor_info != nullptr && tensor_info->DataSize() != 0) {
        has_offline_args = true;
        std::vector<float> offline_data;
        auto elem_count = tensor_info->DataSize();
        if (tensor_info->data_type() == kNumberTypeInt32 || tensor_info->data_type() == kNumberTypeInt) {
          auto raw_datas = static_cast<int32_t *>(tensor_info->data());
          offline_data = std::vector<float>(raw_datas, raw_datas + elem_count);
        } else if (tensor_info->data_type() == kNumberTypeFloat32 || tensor_info->data_type() == kNumberTypeFloat) {
          auto raw_datas = static_cast<float *>(tensor_info->data());
          offline_data = std::vector<float>(raw_datas, raw_datas + elem_count);
        } else {
          MS_LOG(ERROR) << "unsupported param type. " << tensor_info->data_type();
          return RET_ERROR;
        }
        std::vector<int32_t> offline_shape;
        ShapeVector shape_vector;
        if (GetShapeVectorFromParameter(input_param_node, &shape_vector) != RET_OK) {
          MS_LOG(ERROR) << "get shape vector from parameter failed. " << input_param_node->fullname_with_scope();
          return RET_ERROR;
        }
        (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(offline_shape),
                             [](const int64_t dim) { return static_cast<int32_t>(dim); });
        (void)offline_args.emplace_back(std::make_pair(offline_data, offline_shape));
      } else {
        MS_LOG(ERROR) << "tensor_info is nullptr, or DataSize equals zero. " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
    }
  }
  if (has_offline_args) {
    for (auto &offline_arg : offline_args) {
      base_operator->PushOfflineArgs(std::move(offline_arg));
    }
  }
  return RET_OK;
}
}  // namespace dpico
}  // namespace mindspore