1 /**
2 * Copyright 2021-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/graph_builder.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <tuple>
21 #include <set>
22 #include <utility>
23 #include <vector>
24
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "ir/func_graph.h"
27 #include "include/common/utils/utils.h"
28 #include "utils/anf_utils.h"
29 #include "utils/ordered_set.h"
30 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
31 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
32 #include "backend/common/graph_kernel/graph_kernel_flags.h"
33 #include "ir/func_graph_cloner.h"
34 #include "backend/common/graph_kernel/core/value_depend_op_utils.h"
35 #include "include/backend/anf_runtime_algorithm.h"
36 #include "kernel/common_utils.h"
37
38 namespace mindspore::graphkernel {
39 // find outputs of nodes
FindOutputs(const AnfNodePtrList & nodes,const AnfNodePtrToAnfNodePtrMap & eqv)40 AnfNodePtrList FindOutputs(const AnfNodePtrList &nodes, const AnfNodePtrToAnfNodePtrMap &eqv) {
41 AnfNodePtrList output;
42 auto mng = nodes[0]->func_graph()->manager();
43 MS_EXCEPTION_IF_NULL(mng);
44 auto &users = mng->node_users();
45 for (auto &node : nodes) {
46 // only CNode can be an output.
47 if (!node->isa<CNode>()) {
48 continue;
49 }
50 auto iter = users.find(node);
51 if (iter == users.end()) {
52 continue;
53 }
54 auto &node_users = iter->second;
55 // if any user of the `node` is not in the nodes list, the `node` is an output.
56 if (std::any_of(std::begin(node_users), std::end(node_users),
57 [&eqv](const std::pair<AnfNodePtr, int> &u) { return eqv.find(u.first) == eqv.end(); })) {
58 (void)output.emplace_back(node);
59 }
60 }
61 return output;
62 }
63
// Map `node`, an input of a cluster node, to the node that should be referenced inside the new graph `fg`.
//  - A ValueNode that does not hold a FuncGraph is reused directly; eqv records it as mapping to itself.
//  - A node that is already a key of `eqv` returns its recorded counterpart.
//  - Any other node becomes a new parameter of `fg` with the same abstract and kernel info. The original node is
//    appended to `*inputs_ptr` and the mapping is stored in `*eqv_ptr`.
// @return The node to use inside `fg` in place of `node`.
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
                           AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
  auto &mapping = *eqv_ptr;
  const bool reuse_constant = node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node);
  if (reuse_constant) {
    mapping[node] = node;
    return node;
  }
  auto known = mapping.find(node);
  if (known != mapping.end()) {
    return known->second;
  }
  inputs_ptr->push_back(node);
  AnfNodePtr new_param = fg->add_parameter();
  new_param->set_abstract(node->abstract());
  new_param->set_kernel_info(node->kernel_info_ptr());
  mapping[node] = new_param;
  return new_param;
}
77
InlineInnerFuncGraph(const FuncGraphPtr & fg)78 bool InlineInnerFuncGraph(const FuncGraphPtr &fg) {
79 auto mng = fg->manager();
80 MS_EXCEPTION_IF_NULL(mng);
81 bool changed = false;
82 auto cnodes = fg->GetOrderedCnodes();
83 for (const auto &n : cnodes) {
84 auto graph_kernel_g = GetCNodeFuncGraph(n);
85 if (graph_kernel_g == nullptr) {
86 continue;
87 }
88 AnfNodePtrList inp(n->inputs().begin() + 1, n->inputs().end());
89 auto out = InlineClone(graph_kernel_g, fg, inp, n);
90 (void)mng->Replace(n, out);
91 changed = true;
92 }
93 return changed;
94 }
95
EliminateTupleOfTuple(const FuncGraphPtr & fg)96 void EliminateTupleOfTuple(const FuncGraphPtr &fg) {
97 if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
98 return;
99 }
100 auto out_cnode = fg->output()->cast<CNodePtr>();
101 MS_EXCEPTION_IF_NULL(out_cnode);
102 AnfNodePtrList new_args = GkUtils::SpreadTuples(out_cnode->inputs());
103 if (new_args.size() != out_cnode->size()) {
104 auto new_out = fg->NewCNode(new_args);
105 auto mng = fg->manager();
106 MS_EXCEPTION_IF_NULL(mng);
107 (void)mng->Replace(out_cnode, new_out);
108 }
109 AbstractBasePtrList abs_list;
110 (void)std::transform(new_args.begin() + 1, new_args.end(), std::back_inserter(abs_list),
111 [](const AnfNodePtr &node) { return node->abstract(); });
112 fg->output()->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
113 }
114
// Return true when `value` is neither +/-infinity nor NaN.
template <typename T>
bool IsFinite(T value) {
  // std::isfinite is exactly "not inf and not nan".
  return std::isfinite(value);
}
119
IsFiniteScalar(void * data,TypeId type_id)120 bool IsFiniteScalar(void *data, TypeId type_id) {
121 MS_EXCEPTION_IF_NULL(data);
122 // check if float value is inf or nan
123 if (type_id == kNumberTypeFloat64) {
124 auto value = static_cast<double *>(data)[0];
125 return IsFinite(value);
126 } else if (type_id == kNumberTypeFloat32) {
127 auto value = static_cast<float *>(data)[0];
128 return IsFinite(value);
129 } else if (type_id == kNumberTypeFloat16) {
130 float16 *val = static_cast<float16 *>(data);
131 auto value = static_cast<float>(val[0]);
132 return IsFinite(value);
133 }
134 return true;
135 }
136
UpdateBuildInfoOutputKernelObjectType(const AnfNodePtr & node)137 void UpdateBuildInfoOutputKernelObjectType(const AnfNodePtr &node) {
138 if (node->kernel_info() == nullptr) {
139 return;
140 }
141 auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
142 if (build_info != nullptr && build_info->GetAllOutputKernelObjectTypes().empty()) {
143 auto abs_type = AnfAlgo::GetAbstractObjectType(node->abstract());
144 auto object_type = kernel::TypeIdToKernelObjectType(abs_type);
145 build_info->SetOutputsKernelObjectType(std::vector<kernel::KernelObjectType>{object_type});
146 }
147 }
148
ConvertTensorToParameter(const FuncGraphPtr & fg,AnfNodePtrList * inputs_ptr)149 bool ConvertTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
150 auto cnodes = fg->GetOrderedCnodes();
151 mindspore::OrderedSet<AnfNodePtr> value_nodes;
152 for (const auto &cnode : cnodes) {
153 auto &inputs = cnode->inputs();
154 for (size_t i = 1; i < inputs.size(); ++i) {
155 const auto &tnode = inputs[i];
156 auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
157 if (tensor == nullptr) {
158 continue;
159 }
160 auto primitive = GetCNodePrimitive(cnode);
161 // For some primitives, the value in valuenode is required for further optimization.
162 if (ValueDependOpUtils::KeepValueNode(primitive->name(), i - 1)) {
163 continue;
164 }
165 auto type_id = tensor->data_type();
166 // data is nullptr means uninitialized.
167 if (tensor->data().const_data() == nullptr || tensor->DataSize() > 1 ||
168 !IsFiniteScalar(tensor->data_c(), type_id) ||
169 (type_id == kNumberTypeBool && GraphKernelFlags::GetInstance().kernel_generator == "DVM")) {
170 (void)value_nodes.insert(tnode);
171 }
172 }
173 }
174 if (value_nodes.empty()) {
175 return false;
176 }
177 auto mng = fg->manager();
178 if (mng == nullptr) {
179 mng = Manage(fg, false);
180 fg->set_manager(mng);
181 }
182 for (const auto &vnode : value_nodes) {
183 auto parameter = fg->add_parameter();
184 parameter->set_abstract(vnode->abstract());
185 parameter->set_kernel_info(vnode->kernel_info_ptr());
186 UpdateBuildInfoOutputKernelObjectType(parameter);
187 (void)mng->Replace(vnode, parameter);
188 inputs_ptr->push_back(vnode);
189 }
190 return true;
191 }
192
SortParameters(const FuncGraphPtr & fg,AnfNodePtrList * inputs_ptr)193 bool SortParameters(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
194 auto params = fg->parameters();
195 if (params.size() != inputs_ptr->size()) {
196 MS_LOG(EXCEPTION) << "parameters and inputs should have same size, but got " << params.size() << " and "
197 << inputs_ptr->size();
198 }
199 size_t n = inputs_ptr->size();
200 using PairType = std::pair<AnfNodePtr, AnfNodePtr>;
201 std::vector<PairType> normal_pairs;
202 std::vector<PairType> monad_pairs;
203 for (size_t i = 0; i < n; ++i) {
204 if (HasAbstractMonad((*inputs_ptr)[i])) {
205 (void)monad_pairs.emplace_back(params[i], (*inputs_ptr)[i]);
206 } else {
207 (void)normal_pairs.emplace_back(params[i], (*inputs_ptr)[i]);
208 }
209 }
210 if (normal_pairs.empty() || monad_pairs.empty()) {
211 return false;
212 }
213 auto normal_pairs_size = normal_pairs.size();
214 for (size_t i = 0; i < normal_pairs_size; ++i) {
215 params[i] = normal_pairs[i].first;
216 (*inputs_ptr)[i] = normal_pairs[i].second;
217 }
218 for (size_t i = 0; i < monad_pairs.size(); ++i) {
219 params[normal_pairs_size + i] = monad_pairs[i].first;
220 (*inputs_ptr)[normal_pairs_size + i] = monad_pairs[i].second;
221 }
222 fg->set_parameters(std::move(params));
223 return true;
224 }
225
IsTupleOutput(const AnfNodePtr & out,AnfNodePtrList * real_outs)226 bool IsTupleOutput(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
227 if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
228 auto &inputs = out->cast<CNodePtr>()->inputs();
229 real_outs->assign(inputs.begin() + 1, inputs.end());
230 return true;
231 }
232 if (auto fg = GetCNodeFuncGraph(out); fg != nullptr) {
233 return IsTupleOutput(fg->output(), real_outs);
234 }
235 return false;
236 }
237
ReplaceNewFuseCNode(const FuncGraphPtr & func_graph,const AnfNodePtr & new_fuse_cnode,const AnfNodePtrList & outputs)238 void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
239 const AnfNodePtrList &outputs) {
240 MS_EXCEPTION_IF_NULL(func_graph);
241 auto mng = func_graph->manager();
242 MS_EXCEPTION_IF_NULL(mng);
243 // single out
244 if (outputs.size() == 1) {
245 (void)mng->Replace(outputs[0], new_fuse_cnode);
246 return;
247 }
248
249 size_t offset = 0;
250 for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
251 AnfNodePtrList real_outs;
252 // the output is a single tensor
253 if (!IsTupleOutput(outputs[out_idx], &real_outs)) {
254 auto gt_idx = MakeValue(SizeToLong(out_idx + offset));
255 AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
256 gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
257 auto new_out = func_graph->NewCNode(gt_inputs);
258 new_out->set_abstract(outputs[out_idx]->abstract());
259 (void)mng->Replace(outputs[out_idx], new_out);
260 continue;
261 }
262
263 // the out is make tuple , modify the get_item node's value
264 auto users = mng->node_users()[outputs[out_idx]]; // use a copy, the original user map is changed in for-loop.
265 for (auto &user : users) {
266 auto getitem_node = user.first;
267 if (!getitem_node->isa<CNode>() || !IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
268 continue;
269 }
270 auto value_ptr = GetValueNode(getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
271 MS_EXCEPTION_IF_NULL(value_ptr);
272 auto old_gt_idx = GetValue<int64_t>(value_ptr);
273 auto gt_idx = MakeValue(SizeToLong(out_idx + offset) + old_gt_idx);
274 AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
275 gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
276 auto new_getitem_node = func_graph->NewCNode(gt_inputs);
277 new_getitem_node->set_abstract(getitem_node->abstract());
278 (void)mng->Replace(getitem_node, new_getitem_node);
279 }
280
281 offset += real_outs.size() - 1;
282 }
283 }
284
285 // remove parameter which is not used
EliminateRedundantParameters(const FuncGraphPtr & func_graph,AnfNodePtrList * inputs)286 void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
287 MS_EXCEPTION_IF_NULL(inputs);
288 const auto &ori_parameter = func_graph->parameters();
289 auto todos = TopoSort(func_graph->get_return());
290 std::set<AnfNodePtr> used_param;
291 for (auto node : todos) {
292 if (node->isa<Parameter>()) {
293 (void)used_param.insert(node);
294 }
295 }
296 if (used_param.size() == ori_parameter.size()) {
297 return;
298 }
299 AnfNodePtrList new_parameter;
300 AnfNodePtrList new_inputs{(*inputs)[0]};
301 for (size_t i = 0; i < ori_parameter.size(); ++i) {
302 if (used_param.count(ori_parameter[i]) > 0) {
303 new_parameter.push_back(ori_parameter[i]);
304 new_inputs.push_back((*inputs)[i + 1]);
305 }
306 }
307 func_graph->set_parameters(new_parameter);
308 *inputs = std::move(new_inputs);
309 }
310
BuildGraphFromNodes(const AnfNodePtrList & nodes,const ClusterConfig & config)311 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes,
312 const ClusterConfig &config) {
313 FuncGraphPtr fg = nullptr;
314 {
315 // limit the lifetime of guard.
316 TraceGuard guard(std::make_shared<TraceSegmentTransform>(nodes[0]->cast<CNodePtr>()->func_graph()->debug_info()));
317 fg = std::make_shared<FuncGraph>();
318 }
319 AnfNodePtrList input_list;
320 AnfNodePtrToAnfNodePtrMap eqv;
321 // Merge CNodes into a AnfGraph that represents a linear instruction segment
322 for (auto &node : nodes) {
323 auto &node_inputs = node->cast<CNodePtr>()->inputs();
324 std::vector<AnfNodePtr> new_args{node_inputs[0]};
325 (void)std::transform(
326 std::begin(node_inputs) + 1, std::end(node_inputs), std::back_inserter(new_args),
327 [&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); });
328 TraceGuard tg(std::make_shared<TraceSegmentTransform>(node->debug_info()));
329 eqv[node] = fg->NewCNode(new_args);
330 eqv[node]->cast<CNodePtr>()->CloneCNodeInfo(node->cast<CNodePtr>());
331 eqv[node]->cast<CNodePtr>()->set_fullname_with_scope(node->fullname_with_scope());
332 }
333 AnfNodePtrList outputs;
334 if (config.only_output_basenode != nullptr) {
335 // Make base node the only output of func_graph, to duplicate the overlapping parts
336 if (eqv.find(config.only_output_basenode) == eqv.end()) {
337 MS_LOG(EXCEPTION) << "Base node is not in the list of nodes: "
338 << config.only_output_basenode->fullname_with_scope();
339 }
340 outputs.push_back(config.only_output_basenode);
341 } else {
342 outputs = FindOutputs(nodes, eqv);
343 }
344 AnfNodePtr fg_output;
345 if (outputs.size() > 1) {
346 std::vector<AnfNodePtr> output_args;
347 output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
348 (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args),
349 [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
350 // Set output for AnfGraph
351 fg_output = fg->NewCNode(output_args);
352 } else {
353 fg_output = eqv[outputs[0]];
354 }
355 fg->set_output(fg_output);
356 return std::make_tuple(fg, input_list, outputs);
357 }
358
359 // Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs.
BuildSingleGraphFromNodes(const AnfNodePtrList & nodes,const ClusterConfig & config)360 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes,
361 const ClusterConfig &config) {
362 FuncGraphPtr fg;
363 AnfNodePtrList inputs;
364 AnfNodePtrList outputs;
365 std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes, config);
366
367 FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
368 MS_EXCEPTION_IF_NULL(mng);
369
370 if (config.inline_sub_func_graph) {
371 (void)InlineInnerFuncGraph(fg);
372 }
373 // eliminate tuple of tuple, and set Abstract for output MakeTuple
374 EliminateTupleOfTuple(fg);
375 (void)EliminateMaketupleGetitem(fg);
376 (void)ConvertTensorToParameter(fg, &inputs);
377 if (config.sort_parameter) {
378 SortParameters(fg, &inputs);
379 }
380
381 return std::make_tuple(fg, inputs, outputs);
382 }
383
CreateNewFuseCNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & sub_fg,const AnfNodePtrList & inputs)384 CNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) {
385 std::vector<AnfNodePtr> fn_inputs{NewValueNode(sub_fg)};
386 (void)fn_inputs.insert(fn_inputs.end(), inputs.cbegin(), inputs.cend());
387 EliminateRedundantParameters(sub_fg, &fn_inputs);
388 auto fuse_cnode = main_fg->NewCNode(fn_inputs);
389 fuse_cnode->set_abstract(sub_fg->output()->abstract());
390 Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode);
391 return fuse_cnode;
392 }
393
ReplaceNodesWithGraphKernelNode(const AnfNodePtrList & nodes,const FuncGraphPtr & main_graph,const std::string & postfix,const ClusterConfig & config)394 CNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph,
395 const std::string &postfix, const ClusterConfig &config) {
396 auto mng = main_graph->manager();
397 if (mng == nullptr) {
398 mng = Manage(main_graph, true);
399 main_graph->set_manager(mng);
400 }
401 FuncGraphPtr fg;
402 AnfNodePtrList inputs;
403 AnfNodePtrList outputs;
404 std::tie(fg, inputs, outputs) = BuildSingleGraphFromNodes(nodes, config);
405 auto fuse_new_node = CreateNewFuseCNode(main_graph, fg, inputs);
406 ReplaceNewFuseCNode(main_graph, fuse_new_node, outputs);
407 auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", postfix);
408 fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
409 return fuse_new_node;
410 }
411
412 // Eliminate redundant MakeTuple-Getitem edges
EliminateMaketupleGetitem(const FuncGraphPtr & fg)413 bool EliminateMaketupleGetitem(const FuncGraphPtr &fg) {
414 auto nodes = fg->GetOrderedCnodes();
415 auto mng = GkUtils::GetFuncGraphManager(fg);
416 MS_EXCEPTION_IF_NULL(mng);
417 bool changed = false;
418 for (const auto &node : nodes) {
419 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
420 continue;
421 }
422 auto gt = node->cast<CNodePtr>();
423 auto mt = gt->input(kRealInputNodeIndexInTupleGetItem)->cast<CNodePtr>();
424 if (mt == nullptr || !IsPrimitiveCNode(mt, prim::kPrimMakeTuple)) {
425 continue;
426 }
427 auto idx = AnfUtils::GetIntValue(gt->input(kInputNodeOutputIndexInTupleGetItem));
428 (void)mng->Replace(node, mt->input(LongToSize(idx + 1)));
429 changed = true;
430 }
431 return changed;
432 }
433 } // namespace mindspore::graphkernel
434