• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/parallel_op_combine.h"
17 
18 #include <vector>
19 #include <string>
20 #include <set>
21 #include <deque>
22 #include <utility>
23 #include <algorithm>
24 #include <unordered_set>
25 #include "include/backend/anf_runtime_algorithm.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "kernel/framework_utils.h"
28 #include "backend/common/graph_kernel/graph_kernel_helper.h"
29 #include "include/backend/kernel_graph.h"
30 #include "utils/anf_utils.h"
31 #include "include/common/utils/utils.h"
32 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
33 #include "utils/ms_context.h"
34 #include "ops/array_ops.h"
35 #include "backend/common/graph_kernel/adapter/callback_impl.h"
36 
37 namespace mindspore::graphkernel {
38 namespace {
// Attribute key read by GetTransposePerm to obtain a Transpose primitive's permutation.
constexpr auto kPerm = "perm";
// Attribute key under which NewReshapeNode stores the inferred output shape.
constexpr auto kShape = "shape";
// Minimum node size accepted by AutoUpdateInfo, which goes on to read input(1).
// Declared unsigned so the comparison with the node's size() does not mix signedness.
constexpr size_t kMinUpdateSize = 2;
GetTransposePerm(const PrimitivePtr & primitive)42 std::vector<int64_t> GetTransposePerm(const PrimitivePtr &primitive) {
43   ValuePtr perm = primitive->GetAttr(kPerm);
44   MS_EXCEPTION_IF_NULL(perm);
45   auto perm_val = perm->cast<ValueTuplePtr>();
46   MS_EXCEPTION_IF_NULL(perm_val);
47   auto perm_val_data = perm_val->value();
48   std::vector<int64_t> perm_int;
49   (void)std::transform(perm_val_data.begin(), perm_val_data.end(), std::back_inserter(perm_int),
50                        [=](const ValuePtr &e) -> int64_t {
51                          if (e->isa<Int64Imm>()) {
52                            return GetValue<int64_t>(e);
53                          } else if (e->isa<Int32Imm>()) {
54                            return GetValue<int>(e);
55                          } else {
56                            MS_LOG(EXCEPTION) << "Perm must be int";
57                            return -1;
58                          }
59                        });
60   return perm_int;
61 }
62 }  // namespace
BranchGroupFinder(const std::string & op_name,FIsSupportedOp fis_supported_op,FAreCompatibleOps fare_compatible_ops)63 BranchGroupFinder::BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op,
64                                      FAreCompatibleOps fare_compatible_ops)
65     : op_name_(op_name), fis_supported_op_(fis_supported_op), fare_compatible_ops_(fare_compatible_ops) {}
66 
GetConsumers(FuncGraphManagerPtr mng,const AnfNodePtr & producer)67 AnfNodeIndexSet BranchGroupFinder::GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer) {
68   AnfNodeIndexSet consumers;
69   auto users = mng->node_users()[producer];
70   for (auto it : users) {
71     auto user = it.first;
72     if (user && user->cast<CNodePtr>() && AnfUtils::IsRealKernel(user) && fis_supported_op_(user)) {
73       consumers.add(CNodeIndexPair(it.first, it.second));
74       (void)children_map_[producer].insert(user);
75     }
76   }
77   return consumers;
78 }
79 
Find(const AnfNodePtr & start_node,const FuncGraphPtr & func_graph)80 std::vector<Group> BranchGroupFinder::Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph) {
81   auto graph_kernel_fg = func_graph == nullptr ? common::AnfAlgo::GetCNodeFuncGraphPtr(start_node) : func_graph;
82   MS_EXCEPTION_IF_NULL(graph_kernel_fg);
83   auto mng = graph_kernel_fg->manager();
84   MS_EXCEPTION_IF_NULL(mng);
85   auto cnode = start_node->cast<CNodePtr>();
86   MS_EXCEPTION_IF_NULL(cnode);
87   std::deque<AnfNodePtr> init_consumer;
88   (void)std::transform(graph_kernel_fg->parameters().begin(), graph_kernel_fg->parameters().end(),
89                        std::back_inserter(init_consumer), [](const AnfNodePtr &global_in) { return global_in; });
90   for (size_t i = 1; i < cnode->size(); ++i) {
91     init_consumer.push_back(cnode->input(i));
92   }
93   while (!init_consumer.empty()) {
94     auto new_node = init_consumer.front();
95     init_consumer.pop_front();
96     auto new_consumer = GetConsumers(mng, new_node);
97     (void)std::transform(new_consumer.begin(), new_consumer.end(), std::back_inserter(init_consumer),
98                          [](const CNodeIndexPair &index_pair) { return index_pair.first; });
99   }
100   for (auto it : children_map_) {
101     if (it.second.size() > 1) {
102       (void)op_roots_.insert(it.first);
103     }
104   }
105   std::vector<Group> groups;
106   for (const auto &root : op_roots_) {
107     size_t ngroups = groups.size();
108     auto childrens = children_map_.at(root);
109     for (auto child : childrens) {
110       auto prim = GetCNodePrimitive(child);
111       if (!prim) {
112         continue;
113       }
114       auto prim_name = prim->name();
115       // Branch should start with target node that specified by `op_name_`
116       if (prim_name != op_name_) {
117         continue;
118       }
119       auto branch = CreateBranch(child);
120       branch.SetDataRoot(root);
121       auto it = std::find_if(groups.begin() + ngroups, groups.end(), [this, &branch](const Group &group) {
122         MS_EXCEPTION_IF_CHECK_FAIL(!group.empty() && !group[0].ops.empty(), "group empty or group[0] empty");
123         auto top_branch = group[0];
124         return (branch.target_op_pos == top_branch.target_op_pos) &&
125                fare_compatible_ops_(branch.GetTargetOp(), top_branch.GetTargetOp());
126       });
127       if (it != groups.end()) {
128         it->push_back(branch);
129       } else {
130         (void)groups.emplace_back();
131         groups.back().push_back(branch);
132       }
133     }
134   }
135   return groups;
136 }
137 
CreateBranch(AnfNodePtr lead_op)138 Branch BranchGroupFinder::CreateBranch(AnfNodePtr lead_op) {
139   AnfNodePtrList ops{lead_op};
140   int root_idx = GetCNodePrimitive(lead_op)->name() == op_name_ ? 0 : -1;
141   auto it = children_map_.find(lead_op);
142   while (it != children_map_.end() && it->second.size() == 1) {
143     auto node = *(it->second).begin();
144     ops.push_back(node);
145     auto prim_name = GetCNodePrimitive(node)->name();
146     if (prim_name == op_name_) {
147       root_idx = static_cast<int>(ops.size());
148     }
149     it = children_map_.find(node);
150   }
151   return Branch(ops, root_idx);
152 }
153 
ParallelOpCombiner(const std::string & op_name,uint64_t min_num_branches,const std::string & layout)154 ParallelOpCombiner::ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout)
155     : op_name_(op_name), min_num_branches_(min_num_branches), layout_(layout) {}
156 
// Entry point of the combiner.
// Uses `func_graph` as the main graph when it is given, otherwise the func graph obtained from `root`.
// Runs a BranchGroupFinder, takes over its `children_map_`, and calls CombineBranches on every group
// holding at least `min_num_branches_` branches. Returns `combined_`.
// Raises an exception when `root` or the resolved graph is null.
AnfNodePtr ParallelOpCombiner::Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(root);
  main_graph_ = (func_graph != nullptr) ? func_graph : common::AnfAlgo::GetCNodeFuncGraphPtr(root);
  MS_EXCEPTION_IF_NULL(main_graph_);

  auto is_supported = [&](const AnfNodePtr n) { return IsSupportedOp(n); };
  auto can_combine = [&](const AnfNodePtr a, const AnfNodePtr b) { return CanOpsBeCombined(a, b); };
  auto finder = BranchGroupFinder(op_name_, is_supported, can_combine);
  auto groups = finder.Find(root, main_graph_);
  // The parent-to-children relation found by the finder is reused later by AutoUpdateInfo.
  children_map_ = std::move(finder.children_map_);

  for (const Group &group : groups) {
    if (group.size() >= min_num_branches_) {
      CombineBranches(group);
      continue;
    }
    MS_LOG(INFO) << "group size = " << group.size() << " < " << min_num_branches_ << ", skip.";
  }
  return combined_;
}
179 
CombineBranches(const Group & branches)180 void ParallelOpCombiner::CombineBranches(const Group &branches) {
181   auto combined = MakeCombinedOp(branches);
182   auto it = std::min_element(branches.begin(), branches.end(), [](const Branch &branch_a, const Branch &branch_b) {
183     return branch_a.ops.size() < branch_b.ops.size();
184   });
185   size_t depth = it->ops.size();
186   size_t pos;
187   for (pos = 0; pos < depth; ++pos) {
188     if (static_cast<int>(pos) == it->target_op_pos) {
189       continue;
190     }
191     if (!CheckLevel(branches, pos)) {
192       break;
193     }
194     combined = MakeCombinedAnfNodePtrFromFollowingOps(combined, branches, pos);
195   }
196   if (pos > 0) {
197     UpdateGroupOutput(combined, branches, pos - 1);
198   }
199   combined_ = combined;
200 }
201 
CheckLevel(const Group & branches,size_t depth)202 bool ParallelOpCombiner::CheckLevel(const Group &branches, size_t depth) {
203   auto repr = branches[0].ops[depth];
204   auto repr_prim_name = GetCNodePrimitive(repr)->name();
205   // check if all branches in current depth can be combined
206   for (auto it = branches.begin() + 1; it != branches.end(); it++) {
207     const Branch &branch = *it;
208     auto node = branch.ops[depth];
209     auto prim_name = GetCNodePrimitive(node)->name();
210     if (prim_name != repr_prim_name) {
211       MS_LOG(INFO) << "Prim not compatible!" << prim_name << " vs " << repr_prim_name;
212       return false;
213     }
214     if (unsupported_ops_.find(prim_name) != unsupported_ops_.end()) {
215       MS_LOG(INFO) << "Op " << prim_name << " not supported for combination for now, stop.";
216       return false;
217     }
218     if (!IsArgCompatible(repr, node)) {
219       return false;
220     }
221   }
222   MS_LOG(DEBUG) << "Op " << repr_prim_name << " can be combined at depth " << depth;
223   return true;
224 }
225 
AutoUpdateInfo(const CNodePtr & to_update)226 bool ParallelOpCombiner::AutoUpdateInfo(const CNodePtr &to_update) {
227   if (to_update->size() < kMinUpdateSize) {
228     MS_LOG(ERROR) << "Cannot auto update for " << to_update->fullname_with_scope() << " with input size "
229                   << to_update->size();
230     return false;
231   }
232 #ifndef MSLITE_ENABLE_GRAPH_KERNEL
233   Callback::Instance()->ResetKernelInfo(to_update);
234 #else
235   auto rep_input = to_update->input(1);
236   // NOTE: We assume the inputs' formats and types are consistent with outputs'.
237   std::string input_format = Callback::Instance()->GetTargetFromContext() == kAscendDevice ? "" : kOpFormat_NCHW;
238   auto GetPrevOutFormat = [&input_format](const CNodePtr &cnode) -> bool {
239     if (cnode == nullptr || !cnode->HasAttr(kOutputsFormat)) {
240       return false;
241     }
242     auto prev_of = GetValue<std::vector<std::string> >(cnode->GetAttr(kOutputsFormat));
243     if (prev_of.size() > 0) {
244       input_format = prev_of[0];
245       return true;
246     }
247     return false;
248   };
249   if (AnfUtils::IsRealKernel(rep_input)) {
250     (void)GetPrevOutFormat(rep_input->cast<CNodePtr>());
251   }
252   if (input_format.empty()) {
253     auto it = children_map_.find(rep_input);
254     if (it != children_map_.end()) {
255       for (auto orig_user : it->second) {
256         if (GetPrevOutFormat(orig_user->cast<CNodePtr>())) {
257           break;
258         }
259       }
260     }
261   }
262   if (input_format.empty()) {
263     MS_LOG(WARNING) << "Cannot find prev node's input format, use " << layout_
264                     << " by default and that may cause error.";
265     input_format = layout_;
266   }
267   std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(to_update), input_format);
268   to_update->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
269 #endif
270   return true;
271 }
272 
GetUniqueInputs(const Group & branches,size_t depth) const273 std::map<size_t, AnfNodePtrList> ParallelOpCombiner::GetUniqueInputs(const Group &branches, size_t depth) const {
274   std::map<size_t, AnfNodePtrList> unique_inputs;
275   AnfNodePtrList parent_in_branch;
276   if (depth >= 1) {
277     (void)std::transform(branches.begin(), branches.end(), std::back_inserter(parent_in_branch),
278                          [&depth](const Branch &br) { return br.ops[depth - 1]; });
279   } else {
280     Branch b1 = branches[0];
281     parent_in_branch.push_back(b1.GetRootData());
282   }
283 
284   for (auto br : branches) {
285     auto op = br.ops[depth];
286     auto cnode = op->cast<CNodePtr>();
287     // Here we can know for sure that op's arg length are the same (check before)
288     for (size_t i = 1; i < cnode->size(); ++i) {
289       auto in = cnode->input(i);
290       if (std::any_of(parent_in_branch.begin(), parent_in_branch.end(),
291                       [&in](const AnfNodePtr &p) { return in == p; })) {
292         continue;
293       }
294       unique_inputs[i].push_back(in);
295     }
296   }
297   return unique_inputs;
298 }
299 
NewConcatNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & input_node,size_t concat_dim,size_t input_num)300 CNodePtr GraphBuilder::NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node,
301                                      size_t concat_dim, size_t input_num) {
302   std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
303   if (Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
304     auto maketuple = NewTupleNode(func_graph, input_node);
305     concat_inputs.push_back(maketuple);
306   } else {
307     for (size_t i = 0; i < input_node.size(); ++i) {
308       auto n = input_node[i];
309       concat_inputs.push_back(n);
310     }
311   }
312   auto concat = func_graph->NewCNode(concat_inputs);
313   MS_EXCEPTION_IF_NULL(concat);
314   func_graph->AddNode(concat);
315   std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node[0], 0)};
316   auto shape = common::AnfAlgo::GetOutputInferShape(input_node[0], 0);
317   shape[concat_dim] *= SizeToLong(input_num);
318   std::vector<ShapeVector> shapes(1, shape);
319   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
320   common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(concat_dim)), concat);
321   common::AnfAlgo::SetNodeAttr(kAttrN, MakeValue(static_cast<int64_t>(input_num)), concat);
322   return concat;
323 }
324 
NewTupleNode(const FuncGraphPtr & func_graph,AnfNodePtrList shared_inputs)325 CNodePtr GraphBuilder::NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs) {
326   auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
327   AbstractBasePtrList abs_list;
328   for (auto in : shared_inputs) {
329     mk_inputs.push_back(in);
330     abs_list.push_back(in->abstract());
331   }
332   auto make_tuple_node = func_graph->NewCNode(mk_inputs);
333   func_graph->AddNode(make_tuple_node);
334   make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
335   return make_tuple_node;
336 }
337 
NewSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,size_t split_dim,size_t split_num)338 CNodePtr GraphBuilder::NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
339                                     size_t split_num) {
340   if (split_num == 0) {
341     MS_LOG(EXCEPTION) << "split_num should not be zero.";
342   }
343   MS_EXCEPTION_IF_NULL(input_node);
344   std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
345                                           input_node};
346   auto split = func_graph->NewCNode(split_inputs);
347   func_graph->AddNode(split);
348   MS_EXCEPTION_IF_NULL(split);
349   auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
350   std::vector<TypeId> dtypes(split_num, dtype);
351   auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
352   shape[split_dim] /= SizeToLong(split_num);
353   std::vector<ShapeVector> shapes(split_num, shape);
354   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
355   common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(split_dim), split);
356   common::AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue<int64_t>(split_num), split);
357   return split;
358 }
359 
NewElemwiseNoAttrNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & inputs)360 CNodePtr GraphBuilder::NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
361   auto node = func_graph->NewCNode(inputs);
362   func_graph->AddNode(node);
363   MS_EXCEPTION_IF_NULL(node);
364   MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
365   MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
366   std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
367   std::vector<ShapeVector> shapes = {common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0)};
368   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
369   return node;
370 }
371 
NewReshapeNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & inputs,const AnfNodePtr & orig_node)372 CNodePtr GraphBuilder::NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
373                                       const AnfNodePtr &orig_node) {
374   auto node = func_graph->NewCNode(inputs);
375   func_graph->AddNode(node);
376   MS_EXCEPTION_IF_NULL(node);
377   MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
378   MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
379   std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
380   auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
381   auto orig_shape_in = common::AnfAlgo::GetPrevNodeOutputInferShape(orig_node, 0);
382   auto orig_shape_out = common::AnfAlgo::GetOutputInferShape(orig_node, 0);
383   auto new_out_shape = InferReshapeOut(orig_shape_in, orig_shape_out, new_shape_in);
384   GetCNodePrimitive(node)->set_attr(kShape, MakeValue(new_out_shape));
385   std::vector<ShapeVector> shapes = {new_out_shape};
386   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
387   return node;
388 }
389 
NewTransposeNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & inputs)390 CNodePtr GraphBuilder::NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) {
391   auto node = func_graph->NewCNode(inputs);
392   func_graph->AddNode(node);
393   MS_EXCEPTION_IF_NULL(node);
394   MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
395   MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
396   std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
397   auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
398   auto perm_int = GetTransposePerm(GetCNodePrimitive(node));
399   auto new_out_shape = InferTransposeOut(new_shape_in, perm_int);
400   std::vector<ShapeVector> shapes = {new_out_shape};
401   common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
402   return node;
403 }
404 
// Infers the output shape of a rebuilt reshape.
// Only the case where the new input has the same rank as the original input is handled (delegated
// to InferConcatReshapeOut); any other case raises an exception.
ShapeVector GraphBuilder::InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                          const ShapeVector &new_reshape_in) {
  const bool same_rank = orig_reshape_in.size() == new_reshape_in.size();
  if (!same_rank) {
    MS_LOG(EXCEPTION) << "Stack combiner infer for reshape not impl yet";
  }
  return InferConcatReshapeOut(orig_reshape_in, orig_reshape_out, new_reshape_in);
}
415 
// Computes the shape produced by permuting `in_shape` with `perm`: output dimension k is
// `in_shape[perm[k]]`.
// Raises an exception when a permutation entry does not index into `in_shape`.
// NOTE(review): negative entries are passed to LongToSize unchanged; presumably it rejects them — confirm.
ShapeVector GraphBuilder::InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm) {
  ShapeVector out_shape;
  out_shape.reserve(perm.size());
  for (int64_t axis : perm) {
    auto idx = LongToSize(axis);
    if (idx >= in_shape.size()) {
      MS_LOG(EXCEPTION) << "Transpose perm value " << axis << " is out of range for input rank " << in_shape.size();
    }
    out_shape.push_back(in_shape[idx]);
  }
  return out_shape;
}
424 }  // namespace mindspore::graphkernel
425