1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include <algorithm>
19 #include <utility>
20 #include <vector>
21 #include <map>
22 #include <memory>
23
24 #include "src/extendrt/utils/func_graph_utils.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "include/common/utils/convert_utils.h"
29 #include "mindspore/ccsrc/include/backend/optimizer/helper.h"
30
31 #include "ops/op_name.h"
32 #include "tools/optimizer/format/to_nhwc_format.h"
33 #include "tools/optimizer/graph/decrease_transpose_algo.h"
34
35 namespace mindspore {
36 const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
GetNodeValuePtr(AnfNodePtr input_node)37 ValuePtr FuncGraphUtils::GetNodeValuePtr(AnfNodePtr input_node) {
38 if (input_node == nullptr) {
39 return nullptr;
40 }
41 if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
42 input_node = AnfUtils::VisitKernel(input_node, 0).first;
43 }
44 ValuePtr value = nullptr;
45 if (input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
46 auto value_node = input_node->cast<ValueNodePtr>();
47 if (value_node) {
48 value = value_node->value();
49 }
50 } else if (input_node->isa<Parameter>()) {
51 auto parameter = input_node->cast<ParameterPtr>();
52 if (parameter->has_default()) {
53 value = parameter->default_param();
54 }
55 }
56 return value;
57 }
58
GetConstNodeValue(AnfNodePtr input_node)59 tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) {
60 ValuePtr value = GetNodeValuePtr(input_node);
61 if (value == nullptr) {
62 return nullptr;
63 }
64 if (value->isa<tensor::Tensor>()) {
65 auto tensor = value->cast<tensor::TensorPtr>();
66 if (tensor == nullptr || tensor->data().const_data() == nullptr) {
67 return nullptr;
68 }
69 return tensor;
70 }
71 if (value->isa<Scalar>()) {
72 return ScalarToTensor(value->cast<ScalarPtr>());
73 }
74 if (value->isa<ValueTuple>()) {
75 return opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
76 }
77 if (value->isa<Type>()) {
78 auto type_ptr = value->cast<TypePtr>();
79 if (type_ptr == nullptr) {
80 return nullptr;
81 }
82 return std::make_shared<tensor::Tensor>(static_cast<int64_t>(type_ptr->type_id()), type_ptr->type());
83 }
84 MS_LOG(WARNING) << "Unexpected value type " << value->type_name() << " for " << input_node->fullname_with_scope();
85 return nullptr;
86 }
87
GetCNodeOperator(const mindspore::CNodePtr & cnode,mindspore::kernel::BaseOperatorPtr * base_operator)88 bool FuncGraphUtils::GetCNodeOperator(const mindspore::CNodePtr &cnode,
89 mindspore::kernel::BaseOperatorPtr *base_operator) {
90 if (!cnode || !base_operator) {
91 MS_LOG(ERROR) << "Input cnode or base_operator cannot be nullptr";
92 return false;
93 }
94 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
95 MS_EXCEPTION_IF_NULL(prim);
96 if (!prim) {
97 MS_LOG(ERROR) << "Primitive of cnode " << cnode->fullname_with_scope() << " cannot be nullptr";
98 return false;
99 }
100 auto kernel_name = prim->name();
101 ops::PrimitiveCPtr primc_ptr = nullptr;
102 static auto &primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
103 auto primc_it = primc_fns.find(kernel_name);
104 if (primc_it != primc_fns.end() && primc_it->second) {
105 primc_ptr = primc_it->second();
106 }
107 if (primc_ptr == nullptr) {
108 MS_LOG(ERROR) << "OpPrimCRegister can not find " << kernel_name;
109 return false;
110 }
111 (void)primc_ptr->SetAttrs(prim->attrs());
112
113 *base_operator = nullptr;
114 static auto &operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
115 auto op_it = operator_fns.find(kernel_name);
116 if (op_it != operator_fns.end() && op_it->second) {
117 *base_operator = op_it->second(primc_ptr);
118 }
119 if (*base_operator == nullptr) {
120 MS_LOG(ERROR) << "Failed to create operator of type " << kernel_name;
121 return false;
122 }
123 return true;
124 }
125
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)126 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
127 if (node == nullptr || primitive_type == nullptr) {
128 return false;
129 }
130 if (node->isa<CNode>()) {
131 auto cnode = node->cast<CNodePtr>();
132 return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
133 } else if (node->isa<ValueNode>()) {
134 return IsPrimitive(node, primitive_type);
135 }
136 return false;
137 }
138
GetNodeInputs(const AnfNodePtr & anf_node)139 std::vector<common::KernelWithIndex> FuncGraphUtils::GetNodeInputs(const AnfNodePtr &anf_node) {
140 if (anf_node == nullptr) {
141 return {};
142 }
143 if (!anf_node->isa<CNode>()) {
144 return {{anf_node, 0}};
145 }
146 auto cnode = anf_node->cast<CNodePtr>();
147 std::vector<common::KernelWithIndex> inputs;
148 size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
149 for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
150 const auto &pre_node_output = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
151 auto pre_node = pre_node_output.first;
152 if (CheckPrimitiveType(pre_node, prim::kPrimMakeTuple) || CheckPrimitiveType(pre_node, kPrimMakeTupleV2)) {
153 auto tuple_inputs = GetNodeInputs(pre_node);
154 std::copy(tuple_inputs.begin(), tuple_inputs.end(), std::back_inserter(inputs));
155 } else if (CheckPrimitiveType(pre_node, prim::kPrimSplit) &&
156 CheckPrimitiveType(cnode->input(1), prim::kPrimSplit)) {
157 inputs = common::AnfAlgo::GetAllOutputWithIndex(pre_node);
158 } else {
159 inputs.push_back(pre_node_output);
160 }
161 }
162 return inputs;
163 }
164
GetCNodeInputsOutputs(const mindspore::CNodePtr & cnode,std::vector<AnfWithOutIndex> * input_tensors,std::vector<AnfWithOutIndex> * output_tensors)165 bool FuncGraphUtils::GetCNodeInputsOutputs(const mindspore::CNodePtr &cnode,
166 std::vector<AnfWithOutIndex> *input_tensors,
167 std::vector<AnfWithOutIndex> *output_tensors) {
168 if (!cnode || !input_tensors || !output_tensors) {
169 MS_LOG(ERROR) << "Input cnode, input_tensors or output_tensors cannot be nullptr";
170 return false;
171 }
172 // Makeup input tensors.
173 *input_tensors = GetNodeInputs(cnode);
174 // Makeup output tensors.
175 output_tensors->clear();
176 auto output_num = AnfUtils::GetOutputTensorNum(cnode);
177 for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
178 session::KernelWithIndex tensor_id = {cnode, output_idx};
179 output_tensors->push_back(tensor_id);
180 }
181 return true;
182 }
183
GetFuncGraphInputs(const FuncGraphPtr & func_graph,std::vector<AnfWithOutIndex> * inputs)184 bool FuncGraphUtils::GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *inputs) {
185 if (!func_graph || !inputs) {
186 MS_LOG(ERROR) << "Input func_graph or inputs cannot be nullptr";
187 return false;
188 }
189 auto graph_inputs = func_graph->get_inputs();
190 // find parameters of graph inputs
191 for (size_t i = 0; i < graph_inputs.size(); ++i) {
192 auto input = graph_inputs[i];
193 if (input == nullptr) {
194 MS_LOG(ERROR) << "Input " << i << " of FuncGraph is nullptr.";
195 return false;
196 }
197 auto parameter = input->cast<ParameterPtr>();
198 if (!parameter) {
199 MS_LOG(ERROR) << "Input " << input->fullname_with_scope() << " of FuncGraph is not type of Parameter.";
200 return false;
201 }
202 if (common::AnfAlgo::IsParameterWeight(parameter)) {
203 continue;
204 }
205 inputs->push_back(std::make_pair(input, 0));
206 }
207 return true;
208 }
209
GetFuncGraphOutputs(const FuncGraphPtr & func_graph,std::vector<AnfWithOutIndex> * outputs)210 bool FuncGraphUtils::GetFuncGraphOutputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *outputs) {
211 if (func_graph == nullptr) {
212 MS_LOG(ERROR) << "Input func_graph cannot be nullptr!";
213 return false;
214 }
215
216 if (outputs == nullptr) {
217 MS_LOG(ERROR) << "Outputs cannot be nullptr!";
218 return false;
219 }
220
221 *outputs = GetNodeInputs(func_graph->get_return());
222 return true;
223 }
224
GetTensorDataType(const AnfWithOutIndex & tensor)225 DataType FuncGraphUtils::GetTensorDataType(const AnfWithOutIndex &tensor) {
226 auto node = tensor.first;
227 auto output_idx = tensor.second;
228 auto tensor_val = GetConstNodeValue(node);
229 TypeId type_id;
230 if (tensor_val) {
231 type_id = tensor_val->Dtype()->type_id();
232 } else {
233 type_id = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
234 }
235 return static_cast<enum DataType>(type_id);
236 }
237
GetTensorShape(const AnfWithOutIndex & tensor)238 ShapeVector FuncGraphUtils::GetTensorShape(const AnfWithOutIndex &tensor) {
239 auto node = tensor.first;
240 auto output_idx = tensor.second;
241 auto tensor_val = GetConstNodeValue(node);
242 ShapeVector shape;
243 if (tensor_val) {
244 shape = tensor_val->shape_c();
245 } else {
246 shape = common::AnfAlgo::GetOutputInferShape(node, output_idx);
247 }
248 return shape;
249 }
250
UnifyGraphToNHWCFormat(const FuncGraphPtr & graph)251 Status FuncGraphUtils::UnifyGraphToNHWCFormat(const FuncGraphPtr &graph) {
252 auto value = graph->get_attr(ops::kFormat);
253 if (value != nullptr && GetValue<int64_t>(value) != mindspore::NHWC) {
254 auto format_pass = std::make_shared<opt::ToNHWCFormat>();
255 MS_CHECK_TRUE_RET(format_pass != nullptr, kLiteNullptr);
256 if (!format_pass->Run(graph)) {
257 MS_LOG(ERROR) << "DefaultGraphCompiler::Partition Run ToNHWCFormat pass failed";
258 return kLiteNullptr;
259 }
260 auto transpose_pass = std::make_shared<opt::DecreaseTransposeAlgo>();
261 MS_CHECK_TRUE_RET(transpose_pass != nullptr, kLiteNullptr);
262 if (!transpose_pass->Run(graph)) {
263 MS_LOG(ERROR) << "DefaultGraphCompiler::Partition Run DecreaseTransposeAlgo pass failed";
264 return kLiteNullptr;
265 }
266 }
267 return kSuccess;
268 }
269
GetTensorName(const AnfWithOutIndex & tensor)270 std::string FuncGraphUtils::GetTensorName(const AnfWithOutIndex &tensor) {
271 auto node = tensor.first;
272 auto idx = tensor.second;
273 MS_EXCEPTION_IF_NULL(node);
274 AbstractBasePtr abstract = node->abstract();
275 MS_EXCEPTION_IF_NULL(abstract);
276 if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
277 auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract);
278 MS_EXCEPTION_IF_NULL(abstract_tuple);
279 auto abstract_list = abstract_tuple->elements();
280 if (abstract_list.size() <= idx) {
281 MS_LOG(ERROR) << "AbstractTuple's size[" << abstract_list.size() << "] is smaller than expect size[" << idx
282 << "]";
283 return "";
284 }
285 abstract = abstract_list[idx];
286 MS_EXCEPTION_IF_NULL(abstract);
287 }
288 MS_EXCEPTION_IF_NULL(abstract);
289 std::string output_name;
290 if (!abstract->name().empty()) {
291 output_name = abstract->name();
292 } else if (idx > 0) {
293 output_name = node->fullname_with_scope() + ":" + std::to_string(idx);
294 } else {
295 output_name = node->fullname_with_scope();
296 }
297 return output_name;
298 }
299
GetAbstract(const AnfWithOutIndex & tensor)300 AbstractBasePtr FuncGraphUtils::GetAbstract(const AnfWithOutIndex &tensor) {
301 auto node = tensor.first;
302 auto idx = tensor.second;
303 MS_EXCEPTION_IF_NULL(node);
304 AbstractBasePtr abstract = node->abstract();
305 MS_EXCEPTION_IF_NULL(abstract);
306 return common::AnfAlgo::FetchAbstractByIndex(node->abstract(), idx);
307 }
308
GetFuncGraphInputsInfo(const FuncGraphPtr & func_graph,std::vector<tensor::TensorPtr> * inputs,std::vector<std::string> * inputs_name)309 void FuncGraphUtils::GetFuncGraphInputsInfo(const FuncGraphPtr &func_graph, std::vector<tensor::TensorPtr> *inputs,
310 std::vector<std::string> *inputs_name) {
311 MS_EXCEPTION_IF_NULL(func_graph);
312 MS_EXCEPTION_IF_NULL(inputs);
313 MS_EXCEPTION_IF_NULL(inputs_name);
314 std::vector<AnfWithOutIndex> input_idxs;
315 if (!GetFuncGraphInputs(func_graph, &input_idxs)) {
316 MS_LOG(ERROR) << "Failed to get input infos from graph";
317 return;
318 }
319 inputs->clear();
320 inputs_name->clear();
321 for (auto &tensor : input_idxs) {
322 auto name = FuncGraphUtils::GetTensorName(tensor);
323 auto data_type = FuncGraphUtils::GetTensorDataType(tensor);
324 auto shape = FuncGraphUtils::GetTensorShape(tensor);
325 auto ms_tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_type), shape);
326 ms_tensor->set_name(name);
327 inputs->push_back(ms_tensor);
328 inputs_name->push_back(name);
329 }
330 }
331
GetFuncGraphOutputsInfo(const FuncGraphPtr & func_graph,std::vector<tensor::TensorPtr> * outputs,std::vector<std::string> * output_names)332 void FuncGraphUtils::GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std::vector<tensor::TensorPtr> *outputs,
333 std::vector<std::string> *output_names) {
334 MS_EXCEPTION_IF_NULL(func_graph);
335 MS_EXCEPTION_IF_NULL(outputs);
336 MS_EXCEPTION_IF_NULL(output_names);
337 std::vector<AnfWithOutIndex> output_idxs;
338 if (!GetFuncGraphOutputs(func_graph, &output_idxs)) {
339 MS_LOG(ERROR) << "Failed to get input infos from graph";
340 return;
341 }
342 outputs->clear();
343 output_names->clear();
344 for (auto &tensor : output_idxs) {
345 auto name = FuncGraphUtils::GetTensorName(tensor);
346 auto data_type = FuncGraphUtils::GetTensorDataType(tensor);
347 auto shape = FuncGraphUtils::GetTensorShape(tensor);
348 auto ms_tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_type), shape);
349 ms_tensor->set_name(name);
350 outputs->push_back(ms_tensor);
351 output_names->push_back(name);
352 }
353 }
354
TransformSegmentToAnfGraph(const AnfNodePtrList & lst)355 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> FuncGraphUtils::TransformSegmentToAnfGraph(
356 const AnfNodePtrList &lst) {
357 if (lst.empty()) {
358 MS_LOG(EXCEPTION) << "Input anf node list is empty";
359 }
360 FuncGraphPtr fg = nullptr;
361 {
362 // limit the lifetime of guard.
363 MS_EXCEPTION_IF_NULL(lst[0]);
364 MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>());
365 MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>()->func_graph());
366 TraceGuard guard(std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
367 fg = std::make_shared<FuncGraph>();
368 }
369 AnfNodePtrList inputs;
370 mindspore::HashMap<AnfNodePtr, AnfNodePtr> eqv;
371 // Merge CNodes into a AnfGraph that represents a linear instruction segment
372 for (auto n : lst) {
373 MS_EXCEPTION_IF_NULL(n);
374 if (!n->isa<CNode>()) {
375 MS_LOG(EXCEPTION) << "Inst is not CNode";
376 }
377 auto &inps = n->cast<CNodePtr>()->inputs();
378 if (inps.empty()) {
379 MS_LOG(EXCEPTION) << "Input is empty";
380 }
381 if (!IsValueNode<Primitive>(inps[0]) &&
382 !(IsValueNode<FuncGraph>(inps[0]) &&
383 inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
384 MS_LOG(EXCEPTION) << "Input[0] must be a Primitive ValueNode";
385 }
386 auto fn = inps[0];
387 std::vector<AnfNodePtr> args{fn};
388 if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= kDependInputSize &&
389 eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
390 args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
391 const size_t value_start_index = 2;
392 for (size_t i = value_start_index; i < inps.size(); ++i) {
393 args.emplace_back(NewValueNode(MakeValue(0)));
394 }
395 } else {
396 (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
397 [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
398 }
399 TraceGuard tg(std::make_shared<TraceSegmentTransform>(n->debug_info()));
400 MS_EXCEPTION_IF_NULL(fg);
401 eqv[n] = fg->NewCNode(args);
402 eqv[n]->set_abstract(n->abstract());
403 eqv[n]->set_kernel_info(n->kernel_info_ptr());
404 }
405 mindspore::HashSet<AnfNodePtr> eqv_keys;
406 for (auto &e : eqv) {
407 (void)eqv_keys.emplace(e.first);
408 }
409 auto mgr = lst[0]->func_graph()->manager();
410 MS_EXCEPTION_IF_NULL(mgr);
411 auto outputs = GetOutput(lst, mgr->node_users(), eqv_keys);
412 AnfNodePtr fg_output;
413 if (outputs.size() > 1) {
414 std::vector<AnfNodePtr> output_args;
415 output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
416 (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
417 [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
418 // Set output for AnfGraph
419 fg_output = fg->NewCNode(output_args);
420 } else {
421 if (outputs.empty()) {
422 MS_LOG(EXCEPTION) << "Output is empty.";
423 }
424 fg_output = eqv[outputs[0]];
425 }
426 fg->set_output(fg_output);
427 return std::make_tuple(fg, inputs, outputs);
428 }
429
/**
 * Collects the nodes of a segment that are consumed outside of it.
 *
 * @param nodes Candidate nodes, in order; non-CNodes and nodes absent from users are skipped.
 * @param users Node-to-users map.
 * @param seen Nodes considered internal; any user not in this set makes the node an output.
 * @return The CNodes of nodes, in their original order, that have at least one user outside seen.
 *         Empty when users is empty.
 * @throws via MS_EXCEPTION_IF_NULL if an element of nodes is null.
 */
AnfNodePtrList FuncGraphUtils::GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
                                         const mindspore::HashSet<AnfNodePtr> &seen) {
  AnfNodePtrList output;
  if (users.empty()) {
    return output;
  }
  for (auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto iter = users.find(node);
    if (iter == users.end()) {
      continue;
    }
    // Bind each user by const reference to its exact element type, so no converted temporary pair
    // (and no extra shared_ptr copy) is created per user.
    for (const auto &user : iter->second) {
      if (seen.find(user.first) == seen.end()) {
        output.emplace_back(node);
        break;
      }
    }
  }
  return output;
}
457
/**
 * Maps an operand of a segment node to the node that stands for it inside the segment graph fg.
 *
 * Non-FuncGraph value nodes map to themselves. Any other operand seen for the first time is appended to
 * *inputs_ptr and bound to a new parameter of fg that copies the operand's abstract and kernel info.
 * Operands already present in *eqv_ptr return their recorded mapping.
 *
 * @return The node to use in place of node within fg.
 * @throws via MS_EXCEPTION_IF_NULL if any argument is null.
 */
AnfNodePtr FuncGraphUtils::RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
                                           mindspore::HashMap<AnfNodePtr, AnfNodePtr> *eqv_ptr) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(inputs_ptr);
  MS_EXCEPTION_IF_NULL(eqv_ptr);
  MS_EXCEPTION_IF_NULL(node);
  auto &mapping = *eqv_ptr;

  if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
    // Plain constants are reused directly rather than turned into graph parameters.
    mapping[node] = node;
    return node;
  }

  const auto known = mapping.find(node);
  if (known != mapping.end()) {
    return known->second;
  }

  inputs_ptr->push_back(node);
  AnfNodePtr param = fg->add_parameter();
  param->set_abstract(node->abstract());
  param->set_kernel_info(node->kernel_info_ptr());
  mapping[node] = param;
  return param;
}
476 } // namespace mindspore
477