• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <string>
18 #include <algorithm>
19 #include <utility>
20 #include <vector>
21 #include <map>
22 #include <memory>
23 
24 #include "src/extendrt/utils/func_graph_utils.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "include/common/utils/convert_utils.h"
29 #include "mindspore/ccsrc/include/backend/optimizer/helper.h"
30 
31 #include "ops/op_name.h"
32 #include "tools/optimizer/format/to_nhwc_format.h"
33 #include "tools/optimizer/graph/decrease_transpose_algo.h"
34 
35 namespace mindspore {
36 const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
GetNodeValuePtr(AnfNodePtr input_node)37 ValuePtr FuncGraphUtils::GetNodeValuePtr(AnfNodePtr input_node) {
38   if (input_node == nullptr) {
39     return nullptr;
40   }
41   if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
42     input_node = AnfUtils::VisitKernel(input_node, 0).first;
43   }
44   ValuePtr value = nullptr;
45   if (input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
46     auto value_node = input_node->cast<ValueNodePtr>();
47     if (value_node) {
48       value = value_node->value();
49     }
50   } else if (input_node->isa<Parameter>()) {
51     auto parameter = input_node->cast<ParameterPtr>();
52     if (parameter->has_default()) {
53       value = parameter->default_param();
54     }
55   }
56   return value;
57 }
58 
GetConstNodeValue(AnfNodePtr input_node)59 tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) {
60   ValuePtr value = GetNodeValuePtr(input_node);
61   if (value == nullptr) {
62     return nullptr;
63   }
64   if (value->isa<tensor::Tensor>()) {
65     auto tensor = value->cast<tensor::TensorPtr>();
66     if (tensor == nullptr || tensor->data().const_data() == nullptr) {
67       return nullptr;
68     }
69     return tensor;
70   }
71   if (value->isa<Scalar>()) {
72     return ScalarToTensor(value->cast<ScalarPtr>());
73   }
74   if (value->isa<ValueTuple>()) {
75     return opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
76   }
77   if (value->isa<Type>()) {
78     auto type_ptr = value->cast<TypePtr>();
79     if (type_ptr == nullptr) {
80       return nullptr;
81     }
82     return std::make_shared<tensor::Tensor>(static_cast<int64_t>(type_ptr->type_id()), type_ptr->type());
83   }
84   MS_LOG(WARNING) << "Unexpected value type " << value->type_name() << " for " << input_node->fullname_with_scope();
85   return nullptr;
86 }
87 
GetCNodeOperator(const mindspore::CNodePtr & cnode,mindspore::kernel::BaseOperatorPtr * base_operator)88 bool FuncGraphUtils::GetCNodeOperator(const mindspore::CNodePtr &cnode,
89                                       mindspore::kernel::BaseOperatorPtr *base_operator) {
90   if (!cnode || !base_operator) {
91     MS_LOG(ERROR) << "Input cnode or base_operator cannot be nullptr";
92     return false;
93   }
94   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
95   MS_EXCEPTION_IF_NULL(prim);
96   if (!prim) {
97     MS_LOG(ERROR) << "Primitive of cnode " << cnode->fullname_with_scope() << " cannot be nullptr";
98     return false;
99   }
100   auto kernel_name = prim->name();
101   ops::PrimitiveCPtr primc_ptr = nullptr;
102   static auto &primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
103   auto primc_it = primc_fns.find(kernel_name);
104   if (primc_it != primc_fns.end() && primc_it->second) {
105     primc_ptr = primc_it->second();
106   }
107   if (primc_ptr == nullptr) {
108     MS_LOG(ERROR) << "OpPrimCRegister can not find " << kernel_name;
109     return false;
110   }
111   (void)primc_ptr->SetAttrs(prim->attrs());
112 
113   *base_operator = nullptr;
114   static auto &operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
115   auto op_it = operator_fns.find(kernel_name);
116   if (op_it != operator_fns.end() && op_it->second) {
117     *base_operator = op_it->second(primc_ptr);
118   }
119   if (*base_operator == nullptr) {
120     MS_LOG(ERROR) << "Failed to create operator of type " << kernel_name;
121     return false;
122   }
123   return true;
124 }
125 
CheckPrimitiveType(const AnfNodePtr & node,const PrimitivePtr & primitive_type)126 bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
127   if (node == nullptr || primitive_type == nullptr) {
128     return false;
129   }
130   if (node->isa<CNode>()) {
131     auto cnode = node->cast<CNodePtr>();
132     return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
133   } else if (node->isa<ValueNode>()) {
134     return IsPrimitive(node, primitive_type);
135   }
136   return false;
137 }
138 
GetNodeInputs(const AnfNodePtr & anf_node)139 std::vector<common::KernelWithIndex> FuncGraphUtils::GetNodeInputs(const AnfNodePtr &anf_node) {
140   if (anf_node == nullptr) {
141     return {};
142   }
143   if (!anf_node->isa<CNode>()) {
144     return {{anf_node, 0}};
145   }
146   auto cnode = anf_node->cast<CNodePtr>();
147   std::vector<common::KernelWithIndex> inputs;
148   size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
149   for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
150     const auto &pre_node_output = common::AnfAlgo::GetPrevNodeOutput(cnode, input_idx);
151     auto pre_node = pre_node_output.first;
152     if (CheckPrimitiveType(pre_node, prim::kPrimMakeTuple) || CheckPrimitiveType(pre_node, kPrimMakeTupleV2)) {
153       auto tuple_inputs = GetNodeInputs(pre_node);
154       std::copy(tuple_inputs.begin(), tuple_inputs.end(), std::back_inserter(inputs));
155     } else if (CheckPrimitiveType(pre_node, prim::kPrimSplit) &&
156                CheckPrimitiveType(cnode->input(1), prim::kPrimSplit)) {
157       inputs = common::AnfAlgo::GetAllOutputWithIndex(pre_node);
158     } else {
159       inputs.push_back(pre_node_output);
160     }
161   }
162   return inputs;
163 }
164 
GetCNodeInputsOutputs(const mindspore::CNodePtr & cnode,std::vector<AnfWithOutIndex> * input_tensors,std::vector<AnfWithOutIndex> * output_tensors)165 bool FuncGraphUtils::GetCNodeInputsOutputs(const mindspore::CNodePtr &cnode,
166                                            std::vector<AnfWithOutIndex> *input_tensors,
167                                            std::vector<AnfWithOutIndex> *output_tensors) {
168   if (!cnode || !input_tensors || !output_tensors) {
169     MS_LOG(ERROR) << "Input cnode, input_tensors or output_tensors cannot be nullptr";
170     return false;
171   }
172   // Makeup input tensors.
173   *input_tensors = GetNodeInputs(cnode);
174   // Makeup output tensors.
175   output_tensors->clear();
176   auto output_num = AnfUtils::GetOutputTensorNum(cnode);
177   for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
178     session::KernelWithIndex tensor_id = {cnode, output_idx};
179     output_tensors->push_back(tensor_id);
180   }
181   return true;
182 }
183 
GetFuncGraphInputs(const FuncGraphPtr & func_graph,std::vector<AnfWithOutIndex> * inputs)184 bool FuncGraphUtils::GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *inputs) {
185   if (!func_graph || !inputs) {
186     MS_LOG(ERROR) << "Input func_graph or inputs cannot be nullptr";
187     return false;
188   }
189   auto graph_inputs = func_graph->get_inputs();
190   // find parameters of graph inputs
191   for (size_t i = 0; i < graph_inputs.size(); ++i) {
192     auto input = graph_inputs[i];
193     if (input == nullptr) {
194       MS_LOG(ERROR) << "Input " << i << " of FuncGraph is nullptr.";
195       return false;
196     }
197     auto parameter = input->cast<ParameterPtr>();
198     if (!parameter) {
199       MS_LOG(ERROR) << "Input " << input->fullname_with_scope() << " of FuncGraph is not type of Parameter.";
200       return false;
201     }
202     if (common::AnfAlgo::IsParameterWeight(parameter)) {
203       continue;
204     }
205     inputs->push_back(std::make_pair(input, 0));
206   }
207   return true;
208 }
209 
GetFuncGraphOutputs(const FuncGraphPtr & func_graph,std::vector<AnfWithOutIndex> * outputs)210 bool FuncGraphUtils::GetFuncGraphOutputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *outputs) {
211   if (func_graph == nullptr) {
212     MS_LOG(ERROR) << "Input func_graph cannot be nullptr!";
213     return false;
214   }
215 
216   if (outputs == nullptr) {
217     MS_LOG(ERROR) << "Outputs cannot be nullptr!";
218     return false;
219   }
220 
221   *outputs = GetNodeInputs(func_graph->get_return());
222   return true;
223 }
224 
GetTensorDataType(const AnfWithOutIndex & tensor)225 DataType FuncGraphUtils::GetTensorDataType(const AnfWithOutIndex &tensor) {
226   auto node = tensor.first;
227   auto output_idx = tensor.second;
228   auto tensor_val = GetConstNodeValue(node);
229   TypeId type_id;
230   if (tensor_val) {
231     type_id = tensor_val->Dtype()->type_id();
232   } else {
233     type_id = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
234   }
235   return static_cast<enum DataType>(type_id);
236 }
237 
GetTensorShape(const AnfWithOutIndex & tensor)238 ShapeVector FuncGraphUtils::GetTensorShape(const AnfWithOutIndex &tensor) {
239   auto node = tensor.first;
240   auto output_idx = tensor.second;
241   auto tensor_val = GetConstNodeValue(node);
242   ShapeVector shape;
243   if (tensor_val) {
244     shape = tensor_val->shape_c();
245   } else {
246     shape = common::AnfAlgo::GetOutputInferShape(node, output_idx);
247   }
248   return shape;
249 }
250 
UnifyGraphToNHWCFormat(const FuncGraphPtr & graph)251 Status FuncGraphUtils::UnifyGraphToNHWCFormat(const FuncGraphPtr &graph) {
252   auto value = graph->get_attr(ops::kFormat);
253   if (value != nullptr && GetValue<int64_t>(value) != mindspore::NHWC) {
254     auto format_pass = std::make_shared<opt::ToNHWCFormat>();
255     MS_CHECK_TRUE_RET(format_pass != nullptr, kLiteNullptr);
256     if (!format_pass->Run(graph)) {
257       MS_LOG(ERROR) << "DefaultGraphCompiler::Partition Run ToNHWCFormat pass failed";
258       return kLiteNullptr;
259     }
260     auto transpose_pass = std::make_shared<opt::DecreaseTransposeAlgo>();
261     MS_CHECK_TRUE_RET(transpose_pass != nullptr, kLiteNullptr);
262     if (!transpose_pass->Run(graph)) {
263       MS_LOG(ERROR) << "DefaultGraphCompiler::Partition Run DecreaseTransposeAlgo pass failed";
264       return kLiteNullptr;
265     }
266   }
267   return kSuccess;
268 }
269 
GetTensorName(const AnfWithOutIndex & tensor)270 std::string FuncGraphUtils::GetTensorName(const AnfWithOutIndex &tensor) {
271   auto node = tensor.first;
272   auto idx = tensor.second;
273   MS_EXCEPTION_IF_NULL(node);
274   AbstractBasePtr abstract = node->abstract();
275   MS_EXCEPTION_IF_NULL(abstract);
276   if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
277     auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract);
278     MS_EXCEPTION_IF_NULL(abstract_tuple);
279     auto abstract_list = abstract_tuple->elements();
280     if (abstract_list.size() <= idx) {
281       MS_LOG(ERROR) << "AbstractTuple's size[" << abstract_list.size() << "] is smaller than expect size[" << idx
282                     << "]";
283       return "";
284     }
285     abstract = abstract_list[idx];
286     MS_EXCEPTION_IF_NULL(abstract);
287   }
288   MS_EXCEPTION_IF_NULL(abstract);
289   std::string output_name;
290   if (!abstract->name().empty()) {
291     output_name = abstract->name();
292   } else if (idx > 0) {
293     output_name = node->fullname_with_scope() + ":" + std::to_string(idx);
294   } else {
295     output_name = node->fullname_with_scope();
296   }
297   return output_name;
298 }
299 
GetAbstract(const AnfWithOutIndex & tensor)300 AbstractBasePtr FuncGraphUtils::GetAbstract(const AnfWithOutIndex &tensor) {
301   auto node = tensor.first;
302   auto idx = tensor.second;
303   MS_EXCEPTION_IF_NULL(node);
304   AbstractBasePtr abstract = node->abstract();
305   MS_EXCEPTION_IF_NULL(abstract);
306   return common::AnfAlgo::FetchAbstractByIndex(node->abstract(), idx);
307 }
308 
GetFuncGraphInputsInfo(const FuncGraphPtr & func_graph,std::vector<tensor::TensorPtr> * inputs,std::vector<std::string> * inputs_name)309 void FuncGraphUtils::GetFuncGraphInputsInfo(const FuncGraphPtr &func_graph, std::vector<tensor::TensorPtr> *inputs,
310                                             std::vector<std::string> *inputs_name) {
311   MS_EXCEPTION_IF_NULL(func_graph);
312   MS_EXCEPTION_IF_NULL(inputs);
313   MS_EXCEPTION_IF_NULL(inputs_name);
314   std::vector<AnfWithOutIndex> input_idxs;
315   if (!GetFuncGraphInputs(func_graph, &input_idxs)) {
316     MS_LOG(ERROR) << "Failed to get input infos from graph";
317     return;
318   }
319   inputs->clear();
320   inputs_name->clear();
321   for (auto &tensor : input_idxs) {
322     auto name = FuncGraphUtils::GetTensorName(tensor);
323     auto data_type = FuncGraphUtils::GetTensorDataType(tensor);
324     auto shape = FuncGraphUtils::GetTensorShape(tensor);
325     auto ms_tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_type), shape);
326     ms_tensor->set_name(name);
327     inputs->push_back(ms_tensor);
328     inputs_name->push_back(name);
329   }
330 }
331 
GetFuncGraphOutputsInfo(const FuncGraphPtr & func_graph,std::vector<tensor::TensorPtr> * outputs,std::vector<std::string> * output_names)332 void FuncGraphUtils::GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std::vector<tensor::TensorPtr> *outputs,
333                                              std::vector<std::string> *output_names) {
334   MS_EXCEPTION_IF_NULL(func_graph);
335   MS_EXCEPTION_IF_NULL(outputs);
336   MS_EXCEPTION_IF_NULL(output_names);
337   std::vector<AnfWithOutIndex> output_idxs;
338   if (!GetFuncGraphOutputs(func_graph, &output_idxs)) {
339     MS_LOG(ERROR) << "Failed to get input infos from graph";
340     return;
341   }
342   outputs->clear();
343   output_names->clear();
344   for (auto &tensor : output_idxs) {
345     auto name = FuncGraphUtils::GetTensorName(tensor);
346     auto data_type = FuncGraphUtils::GetTensorDataType(tensor);
347     auto shape = FuncGraphUtils::GetTensorShape(tensor);
348     auto ms_tensor = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_type), shape);
349     ms_tensor->set_name(name);
350     outputs->push_back(ms_tensor);
351     output_names->push_back(name);
352   }
353 }
354 
TransformSegmentToAnfGraph(const AnfNodePtrList & lst)355 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> FuncGraphUtils::TransformSegmentToAnfGraph(
356   const AnfNodePtrList &lst) {
357   if (lst.empty()) {
358     MS_LOG(EXCEPTION) << "Input anf node list is empty";
359   }
360   FuncGraphPtr fg = nullptr;
361   {
362     // limit the lifetime of guard.
363     MS_EXCEPTION_IF_NULL(lst[0]);
364     MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>());
365     MS_EXCEPTION_IF_NULL(lst[0]->cast<CNodePtr>()->func_graph());
366     TraceGuard guard(std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
367     fg = std::make_shared<FuncGraph>();
368   }
369   AnfNodePtrList inputs;
370   mindspore::HashMap<AnfNodePtr, AnfNodePtr> eqv;
371   // Merge CNodes into a AnfGraph that represents a linear instruction segment
372   for (auto n : lst) {
373     MS_EXCEPTION_IF_NULL(n);
374     if (!n->isa<CNode>()) {
375       MS_LOG(EXCEPTION) << "Inst is not CNode";
376     }
377     auto &inps = n->cast<CNodePtr>()->inputs();
378     if (inps.empty()) {
379       MS_LOG(EXCEPTION) << "Input is empty";
380     }
381     if (!IsValueNode<Primitive>(inps[0]) &&
382         !(IsValueNode<FuncGraph>(inps[0]) &&
383           inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
384       MS_LOG(EXCEPTION) << "Input[0] must be a Primitive ValueNode";
385     }
386     auto fn = inps[0];
387     std::vector<AnfNodePtr> args{fn};
388     if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= kDependInputSize &&
389         eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
390       args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
391       const size_t value_start_index = 2;
392       for (size_t i = value_start_index; i < inps.size(); ++i) {
393         args.emplace_back(NewValueNode(MakeValue(0)));
394       }
395     } else {
396       (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
397                            [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
398     }
399     TraceGuard tg(std::make_shared<TraceSegmentTransform>(n->debug_info()));
400     MS_EXCEPTION_IF_NULL(fg);
401     eqv[n] = fg->NewCNode(args);
402     eqv[n]->set_abstract(n->abstract());
403     eqv[n]->set_kernel_info(n->kernel_info_ptr());
404   }
405   mindspore::HashSet<AnfNodePtr> eqv_keys;
406   for (auto &e : eqv) {
407     (void)eqv_keys.emplace(e.first);
408   }
409   auto mgr = lst[0]->func_graph()->manager();
410   MS_EXCEPTION_IF_NULL(mgr);
411   auto outputs = GetOutput(lst, mgr->node_users(), eqv_keys);
412   AnfNodePtr fg_output;
413   if (outputs.size() > 1) {
414     std::vector<AnfNodePtr> output_args;
415     output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
416     (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
417                          [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
418     // Set output for AnfGraph
419     fg_output = fg->NewCNode(output_args);
420   } else {
421     if (outputs.empty()) {
422       MS_LOG(EXCEPTION) << "Output is empty.";
423     }
424     fg_output = eqv[outputs[0]];
425   }
426   fg->set_output(fg_output);
427   return std::make_tuple(fg, inputs, outputs);
428 }
429 
// Returns, in input order, the CNodes of `nodes` that have at least one user outside `seen`.
// A user is outside `seen` when it is not a node of the segment being extracted.
// Nodes with no entry in `users` are skipped. An empty user map yields an empty result.
AnfNodePtrList FuncGraphUtils::GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
                                         const mindspore::HashSet<AnfNodePtr> &seen) {
  AnfNodePtrList output;
  if (users.empty()) {
    return output;
  }
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    const auto iter = users.find(node);
    if (iter == users.end()) {
      continue;
    }
    const auto &node_users = iter->second;
    // Bind each (user, index) element by `const auto &`. A hard-coded parameter type such as
    // std::pair<AnfNodePtr, int64_t> builds a converting temporary (including a shared_ptr copy) per
    // element whenever the set's element type differs, e.g. when the index is a plain int.
    const bool has_outer_user = std::any_of(node_users.begin(), node_users.end(), [&seen](const auto &user) {
      return seen.find(user.first) == seen.end();
    });
    if (has_outer_user) {
      output.emplace_back(node);
    }
  }
  return output;
}
457 
// Maps an operand of the source graph to its counterpart inside `fg`, recording the mapping in *eqv_ptr.
//   - Value nodes that are not FuncGraph values map to themselves.
//   - An already-mapped node returns its existing counterpart.
//   - Any other node becomes a new parameter of `fg`, copying the node's abstract and kernel info, and is
//     appended to *inputs_ptr.
// Throws if any argument is null.
AnfNodePtr FuncGraphUtils::RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
                                           mindspore::HashMap<AnfNodePtr, AnfNodePtr> *eqv_ptr) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(inputs_ptr);
  MS_EXCEPTION_IF_NULL(eqv_ptr);
  MS_EXCEPTION_IF_NULL(node);
  auto &mapping = *eqv_ptr;

  if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
    mapping[node] = node;
    return node;
  }

  const auto known = mapping.find(node);
  if (known != mapping.end()) {
    return known->second;
  }

  inputs_ptr->push_back(node);
  AnfNodePtr param = fg->add_parameter();
  param->set_abstract(node->abstract());
  param->set_kernel_info(node->kernel_info_ptr());
  mapping[node] = param;
  return param;
}
476 }  // namespace mindspore
477