• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/graph_builder.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <tuple>
21 #include <set>
22 #include <utility>
23 #include <vector>
24 
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "ir/func_graph.h"
27 #include "include/common/utils/utils.h"
28 #include "utils/anf_utils.h"
29 #include "utils/ordered_set.h"
30 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
31 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
32 #include "backend/common/graph_kernel/graph_kernel_flags.h"
33 #include "ir/func_graph_cloner.h"
34 #include "backend/common/graph_kernel/core/value_depend_op_utils.h"
35 #include "include/backend/anf_runtime_algorithm.h"
36 #include "kernel/common_utils.h"
37 
38 namespace mindspore::graphkernel {
39 // find outputs of nodes
FindOutputs(const AnfNodePtrList & nodes,const AnfNodePtrToAnfNodePtrMap & eqv)40 AnfNodePtrList FindOutputs(const AnfNodePtrList &nodes, const AnfNodePtrToAnfNodePtrMap &eqv) {
41   AnfNodePtrList output;
42   auto mng = nodes[0]->func_graph()->manager();
43   MS_EXCEPTION_IF_NULL(mng);
44   auto &users = mng->node_users();
45   for (auto &node : nodes) {
46     // only CNode can be an output.
47     if (!node->isa<CNode>()) {
48       continue;
49     }
50     auto iter = users.find(node);
51     if (iter == users.end()) {
52       continue;
53     }
54     auto &node_users = iter->second;
55     // if any user of the `node` is not in the nodes list, the `node` is an output.
56     if (std::any_of(std::begin(node_users), std::end(node_users),
57                     [&eqv](const std::pair<AnfNodePtr, int> &u) { return eqv.find(u.first) == eqv.end(); })) {
58       (void)output.emplace_back(node);
59     }
60   }
61   return output;
62 }
63 
// Returns the node that represents `node` inside the sub graph `fg` and records the mapping in `*eqv_ptr`.
//  - A ValueNode that does not hold a FuncGraph is mapped to itself.
//  - A node that already has a mapping is returned as recorded.
//  - Any other node is appended to `*inputs_ptr` and mapped to a new parameter of `fg`; the parameter takes
//    the abstract and the kernel info of `node`.
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
                           AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
  AnfNodePtrToAnfNodePtrMap &mapping = *eqv_ptr;
  const bool is_plain_value = node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node);
  if (is_plain_value) {
    mapping[node] = node;
    return node;
  }
  auto found = mapping.find(node);
  if (found != mapping.end()) {
    return found->second;
  }
  // first time this external node is referenced: it becomes an input of the sub graph.
  inputs_ptr->push_back(node);
  auto param = fg->add_parameter();
  mapping[node] = param;
  param->set_abstract(node->abstract());
  param->set_kernel_info(node->kernel_info_ptr());
  return param;
}
77 
InlineInnerFuncGraph(const FuncGraphPtr & fg)78 bool InlineInnerFuncGraph(const FuncGraphPtr &fg) {
79   auto mng = fg->manager();
80   MS_EXCEPTION_IF_NULL(mng);
81   bool changed = false;
82   auto cnodes = fg->GetOrderedCnodes();
83   for (const auto &n : cnodes) {
84     auto graph_kernel_g = GetCNodeFuncGraph(n);
85     if (graph_kernel_g == nullptr) {
86       continue;
87     }
88     AnfNodePtrList inp(n->inputs().begin() + 1, n->inputs().end());
89     auto out = InlineClone(graph_kernel_g, fg, inp, n);
90     (void)mng->Replace(n, out);
91     changed = true;
92   }
93   return changed;
94 }
95 
EliminateTupleOfTuple(const FuncGraphPtr & fg)96 void EliminateTupleOfTuple(const FuncGraphPtr &fg) {
97   if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
98     return;
99   }
100   auto out_cnode = fg->output()->cast<CNodePtr>();
101   MS_EXCEPTION_IF_NULL(out_cnode);
102   AnfNodePtrList new_args = GkUtils::SpreadTuples(out_cnode->inputs());
103   if (new_args.size() != out_cnode->size()) {
104     auto new_out = fg->NewCNode(new_args);
105     auto mng = fg->manager();
106     MS_EXCEPTION_IF_NULL(mng);
107     (void)mng->Replace(out_cnode, new_out);
108   }
109   AbstractBasePtrList abs_list;
110   (void)std::transform(new_args.begin() + 1, new_args.end(), std::back_inserter(abs_list),
111                        [](const AnfNodePtr &node) { return node->abstract(); });
112   fg->output()->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
113 }
114 
// Returns true when `value` is neither an infinity nor a NaN.
template <typename T>
bool IsFinite(T value) {
  return std::isfinite(value);
}
119 
IsFiniteScalar(void * data,TypeId type_id)120 bool IsFiniteScalar(void *data, TypeId type_id) {
121   MS_EXCEPTION_IF_NULL(data);
122   // check if float value is inf or nan
123   if (type_id == kNumberTypeFloat64) {
124     auto value = static_cast<double *>(data)[0];
125     return IsFinite(value);
126   } else if (type_id == kNumberTypeFloat32) {
127     auto value = static_cast<float *>(data)[0];
128     return IsFinite(value);
129   } else if (type_id == kNumberTypeFloat16) {
130     float16 *val = static_cast<float16 *>(data);
131     auto value = static_cast<float>(val[0]);
132     return IsFinite(value);
133   }
134   return true;
135 }
136 
UpdateBuildInfoOutputKernelObjectType(const AnfNodePtr & node)137 void UpdateBuildInfoOutputKernelObjectType(const AnfNodePtr &node) {
138   if (node->kernel_info() == nullptr) {
139     return;
140   }
141   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
142   if (build_info != nullptr && build_info->GetAllOutputKernelObjectTypes().empty()) {
143     auto abs_type = AnfAlgo::GetAbstractObjectType(node->abstract());
144     auto object_type = kernel::TypeIdToKernelObjectType(abs_type);
145     build_info->SetOutputsKernelObjectType(std::vector<kernel::KernelObjectType>{object_type});
146   }
147 }
148 
ConvertTensorToParameter(const FuncGraphPtr & fg,AnfNodePtrList * inputs_ptr)149 bool ConvertTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
150   auto cnodes = fg->GetOrderedCnodes();
151   mindspore::OrderedSet<AnfNodePtr> value_nodes;
152   for (const auto &cnode : cnodes) {
153     auto &inputs = cnode->inputs();
154     for (size_t i = 1; i < inputs.size(); ++i) {
155       const auto &tnode = inputs[i];
156       auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
157       if (tensor == nullptr) {
158         continue;
159       }
160       auto primitive = GetCNodePrimitive(cnode);
161       // For some primitives, the value in valuenode is required for further optimization.
162       if (ValueDependOpUtils::KeepValueNode(primitive->name(), i - 1)) {
163         continue;
164       }
165       auto type_id = tensor->data_type();
166       // data is nullptr means uninitialized.
167       if (tensor->data().const_data() == nullptr || tensor->DataSize() > 1 ||
168           !IsFiniteScalar(tensor->data_c(), type_id) ||
169           (type_id == kNumberTypeBool && GraphKernelFlags::GetInstance().kernel_generator == "DVM")) {
170         (void)value_nodes.insert(tnode);
171       }
172     }
173   }
174   if (value_nodes.empty()) {
175     return false;
176   }
177   auto mng = fg->manager();
178   if (mng == nullptr) {
179     mng = Manage(fg, false);
180     fg->set_manager(mng);
181   }
182   for (const auto &vnode : value_nodes) {
183     auto parameter = fg->add_parameter();
184     parameter->set_abstract(vnode->abstract());
185     parameter->set_kernel_info(vnode->kernel_info_ptr());
186     UpdateBuildInfoOutputKernelObjectType(parameter);
187     (void)mng->Replace(vnode, parameter);
188     inputs_ptr->push_back(vnode);
189   }
190   return true;
191 }
192 
SortParameters(const FuncGraphPtr & fg,AnfNodePtrList * inputs_ptr)193 bool SortParameters(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
194   auto params = fg->parameters();
195   if (params.size() != inputs_ptr->size()) {
196     MS_LOG(EXCEPTION) << "parameters and inputs should have same size, but got " << params.size() << " and "
197                       << inputs_ptr->size();
198   }
199   size_t n = inputs_ptr->size();
200   using PairType = std::pair<AnfNodePtr, AnfNodePtr>;
201   std::vector<PairType> normal_pairs;
202   std::vector<PairType> monad_pairs;
203   for (size_t i = 0; i < n; ++i) {
204     if (HasAbstractMonad((*inputs_ptr)[i])) {
205       (void)monad_pairs.emplace_back(params[i], (*inputs_ptr)[i]);
206     } else {
207       (void)normal_pairs.emplace_back(params[i], (*inputs_ptr)[i]);
208     }
209   }
210   if (normal_pairs.empty() || monad_pairs.empty()) {
211     return false;
212   }
213   auto normal_pairs_size = normal_pairs.size();
214   for (size_t i = 0; i < normal_pairs_size; ++i) {
215     params[i] = normal_pairs[i].first;
216     (*inputs_ptr)[i] = normal_pairs[i].second;
217   }
218   for (size_t i = 0; i < monad_pairs.size(); ++i) {
219     params[normal_pairs_size + i] = monad_pairs[i].first;
220     (*inputs_ptr)[normal_pairs_size + i] = monad_pairs[i].second;
221   }
222   fg->set_parameters(std::move(params));
223   return true;
224 }
225 
IsTupleOutput(const AnfNodePtr & out,AnfNodePtrList * real_outs)226 bool IsTupleOutput(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
227   if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
228     auto &inputs = out->cast<CNodePtr>()->inputs();
229     real_outs->assign(inputs.begin() + 1, inputs.end());
230     return true;
231   }
232   if (auto fg = GetCNodeFuncGraph(out); fg != nullptr) {
233     return IsTupleOutput(fg->output(), real_outs);
234   }
235   return false;
236 }
237 
ReplaceNewFuseCNode(const FuncGraphPtr & func_graph,const AnfNodePtr & new_fuse_cnode,const AnfNodePtrList & outputs)238 void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
239                          const AnfNodePtrList &outputs) {
240   MS_EXCEPTION_IF_NULL(func_graph);
241   auto mng = func_graph->manager();
242   MS_EXCEPTION_IF_NULL(mng);
243   // single out
244   if (outputs.size() == 1) {
245     (void)mng->Replace(outputs[0], new_fuse_cnode);
246     return;
247   }
248 
249   size_t offset = 0;
250   for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
251     AnfNodePtrList real_outs;
252     // the output is a single tensor
253     if (!IsTupleOutput(outputs[out_idx], &real_outs)) {
254       auto gt_idx = MakeValue(SizeToLong(out_idx + offset));
255       AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
256       gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
257       auto new_out = func_graph->NewCNode(gt_inputs);
258       new_out->set_abstract(outputs[out_idx]->abstract());
259       (void)mng->Replace(outputs[out_idx], new_out);
260       continue;
261     }
262 
263     // the out is make tuple , modify the get_item node's value
264     auto users = mng->node_users()[outputs[out_idx]];  // use a copy, the original user map is changed in for-loop.
265     for (auto &user : users) {
266       auto getitem_node = user.first;
267       if (!getitem_node->isa<CNode>() || !IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
268         continue;
269       }
270       auto value_ptr = GetValueNode(getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
271       MS_EXCEPTION_IF_NULL(value_ptr);
272       auto old_gt_idx = GetValue<int64_t>(value_ptr);
273       auto gt_idx = MakeValue(SizeToLong(out_idx + offset) + old_gt_idx);
274       AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
275       gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
276       auto new_getitem_node = func_graph->NewCNode(gt_inputs);
277       new_getitem_node->set_abstract(getitem_node->abstract());
278       (void)mng->Replace(getitem_node, new_getitem_node);
279     }
280 
281     offset += real_outs.size() - 1;
282   }
283 }
284 
285 // remove parameter which is not used
EliminateRedundantParameters(const FuncGraphPtr & func_graph,AnfNodePtrList * inputs)286 void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
287   MS_EXCEPTION_IF_NULL(inputs);
288   const auto &ori_parameter = func_graph->parameters();
289   auto todos = TopoSort(func_graph->get_return());
290   std::set<AnfNodePtr> used_param;
291   for (auto node : todos) {
292     if (node->isa<Parameter>()) {
293       (void)used_param.insert(node);
294     }
295   }
296   if (used_param.size() == ori_parameter.size()) {
297     return;
298   }
299   AnfNodePtrList new_parameter;
300   AnfNodePtrList new_inputs{(*inputs)[0]};
301   for (size_t i = 0; i < ori_parameter.size(); ++i) {
302     if (used_param.count(ori_parameter[i]) > 0) {
303       new_parameter.push_back(ori_parameter[i]);
304       new_inputs.push_back((*inputs)[i + 1]);
305     }
306   }
307   func_graph->set_parameters(new_parameter);
308   *inputs = std::move(new_inputs);
309 }
310 
BuildGraphFromNodes(const AnfNodePtrList & nodes,const ClusterConfig & config)311 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes,
312                                                                              const ClusterConfig &config) {
313   FuncGraphPtr fg = nullptr;
314   {
315     // limit the lifetime of guard.
316     TraceGuard guard(std::make_shared<TraceSegmentTransform>(nodes[0]->cast<CNodePtr>()->func_graph()->debug_info()));
317     fg = std::make_shared<FuncGraph>();
318   }
319   AnfNodePtrList input_list;
320   AnfNodePtrToAnfNodePtrMap eqv;
321   // Merge CNodes into a AnfGraph that represents a linear instruction segment
322   for (auto &node : nodes) {
323     auto &node_inputs = node->cast<CNodePtr>()->inputs();
324     std::vector<AnfNodePtr> new_args{node_inputs[0]};
325     (void)std::transform(
326       std::begin(node_inputs) + 1, std::end(node_inputs), std::back_inserter(new_args),
327       [&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); });
328     TraceGuard tg(std::make_shared<TraceSegmentTransform>(node->debug_info()));
329     eqv[node] = fg->NewCNode(new_args);
330     eqv[node]->cast<CNodePtr>()->CloneCNodeInfo(node->cast<CNodePtr>());
331     eqv[node]->cast<CNodePtr>()->set_fullname_with_scope(node->fullname_with_scope());
332   }
333   AnfNodePtrList outputs;
334   if (config.only_output_basenode != nullptr) {
335     // Make base node the only output of func_graph, to duplicate the overlapping parts
336     if (eqv.find(config.only_output_basenode) == eqv.end()) {
337       MS_LOG(EXCEPTION) << "Base node is not in the list of nodes: "
338                         << config.only_output_basenode->fullname_with_scope();
339     }
340     outputs.push_back(config.only_output_basenode);
341   } else {
342     outputs = FindOutputs(nodes, eqv);
343   }
344   AnfNodePtr fg_output;
345   if (outputs.size() > 1) {
346     std::vector<AnfNodePtr> output_args;
347     output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
348     (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
349                          [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
350     // Set output for AnfGraph
351     fg_output = fg->NewCNode(output_args);
352   } else {
353     fg_output = eqv[outputs[0]];
354   }
355   fg->set_output(fg_output);
356   return std::make_tuple(fg, input_list, outputs);
357 }
358 
359 // Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs.
BuildSingleGraphFromNodes(const AnfNodePtrList & nodes,const ClusterConfig & config)360 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes,
361                                                                                    const ClusterConfig &config) {
362   FuncGraphPtr fg;
363   AnfNodePtrList inputs;
364   AnfNodePtrList outputs;
365   std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes, config);
366 
367   FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
368   MS_EXCEPTION_IF_NULL(mng);
369 
370   if (config.inline_sub_func_graph) {
371     (void)InlineInnerFuncGraph(fg);
372   }
373   // eliminate tuple of tuple, and set Abstract for output MakeTuple
374   EliminateTupleOfTuple(fg);
375   (void)EliminateMaketupleGetitem(fg);
376   (void)ConvertTensorToParameter(fg, &inputs);
377   if (config.sort_parameter) {
378     SortParameters(fg, &inputs);
379   }
380 
381   return std::make_tuple(fg, inputs, outputs);
382 }
383 
CreateNewFuseCNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & sub_fg,const AnfNodePtrList & inputs)384 CNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) {
385   std::vector<AnfNodePtr> fn_inputs{NewValueNode(sub_fg)};
386   (void)fn_inputs.insert(fn_inputs.end(), inputs.cbegin(), inputs.cend());
387   EliminateRedundantParameters(sub_fg, &fn_inputs);
388   auto fuse_cnode = main_fg->NewCNode(fn_inputs);
389   fuse_cnode->set_abstract(sub_fg->output()->abstract());
390   Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode);
391   return fuse_cnode;
392 }
393 
ReplaceNodesWithGraphKernelNode(const AnfNodePtrList & nodes,const FuncGraphPtr & main_graph,const std::string & postfix,const ClusterConfig & config)394 CNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph,
395                                          const std::string &postfix, const ClusterConfig &config) {
396   auto mng = main_graph->manager();
397   if (mng == nullptr) {
398     mng = Manage(main_graph, true);
399     main_graph->set_manager(mng);
400   }
401   FuncGraphPtr fg;
402   AnfNodePtrList inputs;
403   AnfNodePtrList outputs;
404   std::tie(fg, inputs, outputs) = BuildSingleGraphFromNodes(nodes, config);
405   auto fuse_new_node = CreateNewFuseCNode(main_graph, fg, inputs);
406   ReplaceNewFuseCNode(main_graph, fuse_new_node, outputs);
407   auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", postfix);
408   fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
409   return fuse_new_node;
410 }
411 
412 // Eliminate redundant MakeTuple-Getitem edges
EliminateMaketupleGetitem(const FuncGraphPtr & fg)413 bool EliminateMaketupleGetitem(const FuncGraphPtr &fg) {
414   auto nodes = fg->GetOrderedCnodes();
415   auto mng = GkUtils::GetFuncGraphManager(fg);
416   MS_EXCEPTION_IF_NULL(mng);
417   bool changed = false;
418   for (const auto &node : nodes) {
419     if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
420       continue;
421     }
422     auto gt = node->cast<CNodePtr>();
423     auto mt = gt->input(kRealInputNodeIndexInTupleGetItem)->cast<CNodePtr>();
424     if (mt == nullptr || !IsPrimitiveCNode(mt, prim::kPrimMakeTuple)) {
425       continue;
426     }
427     auto idx = AnfUtils::GetIntValue(gt->input(kInputNodeOutputIndexInTupleGetItem));
428     (void)mng->Replace(node, mt->input(LongToSize(idx + 1)));
429     changed = true;
430   }
431   return changed;
432 }
433 }  // namespace mindspore::graphkernel
434